diff --git a/gns3server/api/routes/controller/__init__.py b/gns3server/api/routes/controller/__init__.py index 2e12010e..a0f51f7e 100644 --- a/gns3server/api/routes/controller/__init__.py +++ b/gns3server/api/routes/controller/__init__.py @@ -28,6 +28,7 @@ from . import projects from . import snapshots from . import symbols from . import templates +from . import images from . import users from . import groups from . import roles @@ -61,9 +62,17 @@ router.include_router( tags=["Permissions"] ) +router.include_router( + images.router, + dependencies=[Depends(get_current_active_user)], + prefix="/images", + tags=["Images"] +) + router.include_router( templates.router, dependencies=[Depends(get_current_active_user)], + prefix="/templates", tags=["Templates"] ) diff --git a/gns3server/api/routes/controller/groups.py b/gns3server/api/routes/controller/groups.py index a028985b..4e9cb8a8 100644 --- a/gns3server/api/routes/controller/groups.py +++ b/gns3server/api/routes/controller/groups.py @@ -25,6 +25,7 @@ from typing import List from gns3server import schemas from gns3server.controller.controller_error import ( + ControllerError, ControllerBadRequestError, ControllerNotFoundError, ControllerForbiddenError, @@ -126,7 +127,7 @@ async def delete_user_group( success = await users_repo.delete_user_group(user_group_id) if not success: - raise ControllerNotFoundError(f"User group '{user_group_id}' could not be deleted") + raise ControllerError(f"User group '{user_group_id}' could not be deleted") @router.get("/{user_group_id}/members", response_model=List[schemas.User]) diff --git a/gns3server/api/routes/controller/images.py b/gns3server/api/routes/controller/images.py new file mode 100644 index 00000000..739d7e19 --- /dev/null +++ b/gns3server/api/routes/controller/images.py @@ -0,0 +1,122 @@ +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" +API routes for images. +""" + +import os +import logging +import urllib.parse + +from fastapi import APIRouter, Request, Depends, status +from typing import List +from gns3server import schemas + +from gns3server.utils.images import InvalidImageError, default_images_directory, write_image +from gns3server.db.repositories.images import ImagesRepository +from gns3server.controller.controller_error import ( + ControllerError, + ControllerNotFoundError, + ControllerForbiddenError, + ControllerBadRequestError +) + +from .dependencies.database import get_repository + +log = logging.getLogger(__name__) + +router = APIRouter() + + +@router.get("") +async def get_images( + images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), +) -> List[schemas.Image]: + """ + Return all images. + """ + + return await images_repo.get_images() + + +@router.post("/upload/{image_name}", response_model=schemas.Image, status_code=status.HTTP_201_CREATED) +async def upload_image( + image_name: str, + image_type: schemas.ImageType, + request: Request, + images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), +) -> schemas.Image: + """ + Upload an image. + """ + + image_name = urllib.parse.unquote(image_name) + directory = default_images_directory(image_type) + path = os.path.abspath(os.path.join(directory, image_name)) + if os.path.commonprefix([directory, path]) != directory: + raise ControllerForbiddenError(f"Could not write image: {image_name}, '{path}' is forbidden") + + if await images_repo.get_image(image_name): + raise ControllerBadRequestError(f"Image '{image_name}' already exists") + + try: + image = await write_image(image_name, image_type, path, request.stream(), images_repo) + except (OSError, InvalidImageError) as e: + raise ControllerError(f"Could not save {image_type} image '{image_name}': {e}") + + return image + + +@router.get("/{image_name}", response_model=schemas.Image) +async def get_image( + image_name: str, + images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), +) -> schemas.Image: + """ + Return an image. + """ + + image = await images_repo.get_image(image_name) + if not image: + raise ControllerNotFoundError(f"Image '{image_name}' not found") + return image + + +@router.delete("/{image_name}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_image( + image_name: str, + images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), +) -> None: + """ + Delete an image. + """ + + image = await images_repo.get_image(image_name) + if not image: + raise ControllerNotFoundError(f"Image '{image_name}' not found") + + if await images_repo.get_image_templates(image.id): + raise ControllerError(f"Image '{image_name}' is used by one or more templates") + + try: + os.remove(image.path) + except OSError: + log.warning(f"Could not delete image file {image.path}") + + success = await images_repo.delete_image(image_name) + if not success: + raise ControllerError(f"Image '{image_name}' could not be deleted") diff --git a/gns3server/api/routes/controller/permissions.py b/gns3server/api/routes/controller/permissions.py index 466a2707..3c033cef 100644 --- a/gns3server/api/routes/controller/permissions.py +++ b/gns3server/api/routes/controller/permissions.py @@ -25,6 +25,7 @@ from typing import List from gns3server import schemas from gns3server.controller.controller_error import ( + ControllerError, ControllerBadRequestError, ControllerNotFoundError, ControllerForbiddenError, @@ -114,4 +115,4 @@ async def delete_permission( success = await rbac_repo.delete_permission(permission_id) if not success: - raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted") + raise ControllerError(f"Permission '{permission_id}' could not be deleted") diff --git a/gns3server/api/routes/controller/roles.py b/gns3server/api/routes/controller/roles.py index fb8be351..b28c25e4 100644 --- a/gns3server/api/routes/controller/roles.py +++ b/gns3server/api/routes/controller/roles.py @@ -25,6 +25,7 @@ from typing import List from gns3server import schemas from gns3server.controller.controller_error import ( + ControllerError, ControllerBadRequestError, ControllerNotFoundError, ControllerForbiddenError, @@ -119,7 +120,7 @@ async def delete_role( success = await rbac_repo.delete_role(role_id) if not success: - raise ControllerNotFoundError(f"Role '{role_id}' could not be deleted") + raise ControllerError(f"Role '{role_id}' could not be deleted") @router.get("/{role_id}/permissions", response_model=List[schemas.Permission]) diff --git a/gns3server/api/routes/controller/templates.py b/gns3server/api/routes/controller/templates.py index 9f1cf07d..dac698b1 100644 --- a/gns3server/api/routes/controller/templates.py +++ b/gns3server/api/routes/controller/templates.py @@ -42,7 +42,7 @@ responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find router = APIRouter(responses=responses) -@router.post("/templates", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) +@router.post("", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) async def create_template( template_create: schemas.TemplateCreate, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), @@ -59,7 +59,7 @@ async def create_template( return template -@router.get("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) +@router.get("/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) async def get_template( template_id: UUID, request: Request, @@ -81,7 +81,7 @@ async def get_template( return template -@router.put("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) +@router.put("/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) async def update_template( template_id: UUID, template_update: schemas.TemplateUpdate, @@ -94,10 +94,7 @@ async def update_template( return await TemplatesService(templates_repo).update_template(template_id, template_update) -@router.delete( - "/templates/{template_id}", - status_code=status.HTTP_204_NO_CONTENT, -) +@router.delete("/{template_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_template( template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), @@ -111,7 +108,7 @@ async def delete_template( await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}") -@router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True) +@router.get("", response_model=List[schemas.Template], response_model_exclude_unset=True) async def get_templates( templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), current_user: schemas.User = Depends(get_current_active_user), @@ -138,7 +135,7 @@ async def get_templates( return user_templates -@router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) +@router.post("/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) async def duplicate_template( template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), current_user: schemas.User = Depends(get_current_active_user), diff --git a/gns3server/api/routes/controller/users.py b/gns3server/api/routes/controller/users.py index edf53b9a..421cda0c 100644 --- a/gns3server/api/routes/controller/users.py +++ b/gns3server/api/routes/controller/users.py @@ -26,6 +26,7 @@ from typing import List from gns3server import schemas from gns3server.controller.controller_error import ( + ControllerError, ControllerBadRequestError, ControllerNotFoundError, ControllerForbiddenError, @@ -194,7 +195,7 @@ async def delete_user( success = await users_repo.delete_user(user_id) if not success: - raise ControllerNotFoundError(f"User '{user_id}' could not be deleted") + raise ControllerError(f"User '{user_id}' could not be deleted") @router.get( diff --git a/gns3server/db/models/__init__.py b/gns3server/db/models/__init__.py index ed5f7ead..d10d0668 100644 --- a/gns3server/db/models/__init__.py +++ b/gns3server/db/models/__init__.py @@ -20,6 +20,7 @@ from .users import User, UserGroup from .roles import Role from .permissions import Permission from .computes import Compute +from .images import Image from .templates import ( Template, CloudTemplate, diff --git a/gns3server/db/models/images.py b/gns3server/db/models/images.py new file mode 100644 index 00000000..3565d4c5 --- /dev/null +++ b/gns3server/db/models/images.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from sqlalchemy import Column, String, Integer +from sqlalchemy.orm import relationship + +from .base import BaseTable + + +class Image(BaseTable): + + __tablename__ = "images" + + id = Column(Integer, primary_key=True, autoincrement=True) + filename = Column(String, unique=True, index=True) + image_type = Column(String) + path = Column(String) + checksum = Column(String) + checksum_algorithm = Column(String) + templates = relationship("Template") diff --git a/gns3server/db/models/templates.py b/gns3server/db/models/templates.py index 470b88cb..75795039 100644 --- a/gns3server/db/models/templates.py +++ b/gns3server/db/models/templates.py @@ -35,6 +35,8 @@ class Template(BaseTable): usage = Column(String) template_type = Column(String) + image_id = Column(Integer, ForeignKey('images.id', ondelete="CASCADE")) + __mapper_args__ = { "polymorphic_identity": "templates", "polymorphic_on": template_type, diff --git a/gns3server/db/repositories/computes.py b/gns3server/db/repositories/computes.py index b76842cb..2ea00bbd 100644 --- a/gns3server/db/repositories/computes.py +++ b/gns3server/db/repositories/computes.py @@ -23,15 +23,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from .base import BaseRepository import gns3server.db.models as models -from gns3server.services import auth_service from gns3server import schemas class ComputesRepository(BaseRepository): + def __init__(self, db_session: AsyncSession) -> None: super().__init__(db_session) - self._auth_service = auth_service async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]: diff --git a/gns3server/db/repositories/images.py b/gns3server/db/repositories/images.py new file mode 100644 index 00000000..d9ef8d91 --- /dev/null +++ b/gns3server/db/repositories/images.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from typing import Optional, List +from sqlalchemy import select, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from .base import BaseRepository + +import gns3server.db.models as models + + +class ImagesRepository(BaseRepository): + + def __init__(self, db_session: AsyncSession) -> None: + + super().__init__(db_session) + + async def get_image(self, image_name: str) -> Optional[models.Image]: + + query = select(models.Image).where(models.Image.filename == image_name) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_images(self) -> List[models.Image]: + + query = select(models.Image) + result = await self._db_session.execute(query) + return result.scalars().all() + + async def get_image_templates(self, image_id: int) -> Optional[List[models.Template]]: + + query = select(models.Template).\ + join(models.Image.templates). \ + filter(models.Image.id == image_id) + + result = await self._db_session.execute(query) + return result.scalars().all() + + async def get_image_by_checksum(self, checksum: str) -> Optional[models.Image]: + + query = select(models.Image).where(models.Image.checksum == checksum) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def add_image(self, image_name, image_type, path, checksum, checksum_algorithm) -> models.Image: + + db_image = models.Image( + id=None, + filename=image_name, + image_type=image_type, + path=path, + checksum=checksum, + checksum_algorithm=checksum_algorithm + ) + + self._db_session.add(db_image) + await self._db_session.commit() + await self._db_session.refresh(db_image) + return db_image + + async def delete_image(self, image_name: str) -> bool: + + query = delete(models.Image).where(models.Image.filename == image_name) + result = await self._db_session.execute(query) + await self._db_session.commit() + return result.rowcount > 0 diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py index e5683059..23b89f72 100644 --- a/gns3server/schemas/__init__.py +++ b/gns3server/schemas/__init__.py @@ -23,6 +23,7 @@ from .version import Version from .controller.links import LinkCreate, LinkUpdate, Link from .controller.computes import ComputeCreate, ComputeUpdate, AutoIdlePC, Compute from .controller.templates import TemplateCreate, TemplateUpdate, TemplateUsage, Template +from .controller.images import Image, ImageType from .controller.drawings import Drawing from .controller.gns3vm import GNS3VM from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node diff --git a/gns3server/schemas/controller/images.py b/gns3server/schemas/controller/images.py new file mode 100644 index 00000000..992e6742 --- /dev/null +++ b/gns3server/schemas/controller/images.py @@ -0,0 +1,44 @@ +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from pydantic import BaseModel, Field +from enum import Enum + +from .base import DateTimeModelMixin + + +class ImageType(str, Enum): + + qemu = "qemu" + ios = "ios" + iou = "iou" + + +class ImageBase(BaseModel): + """ + Common image properties. + """ + + filename: str = Field(..., description="Image name") + image_type: ImageType = Field(..., description="Image type") + checksum: str = Field(..., description="Checksum value") + checksum_algorithm: str = Field(..., description="Checksum algorithm") + + +class Image(DateTimeModelMixin, ImageBase): + + class Config: + orm_mode = True diff --git a/gns3server/utils/images.py b/gns3server/utils/images.py index 5a1fdb2d..3843d6f6 100644 --- a/gns3server/utils/images.py +++ b/gns3server/utils/images.py @@ -16,21 +16,27 @@ import os import hashlib +import stat +import aiofiles +import shutil +from typing import AsyncGenerator from ..config import Config from . import force_unix_path +import gns3server.db.models as models +from gns3server.db.repositories.images import ImagesRepository import logging log = logging.getLogger(__name__) -def list_images(type): +def list_images(image_type): """ - Scan directories for available image for a type + Scan directories for available image for a given type. - :param type: emulator type (dynamips, qemu, iou) + :param image_type: image type (dynamips, qemu, iou) """ files = set() images = [] @@ -39,9 +45,9 @@ def list_images(type): general_images_directory = os.path.expanduser(server_config.images_path) # Subfolder of the general_images_directory specific to this VM type - default_directory = default_images_directory(type) + default_directory = default_images_directory(image_type) - for directory in images_directories(type): + for directory in images_directories(image_type): # We limit recursion to path outside the default images directory # the reason is in the default directory manage file organization and @@ -58,9 +64,9 @@ def list_images(type): if filename.endswith(".md5sum") or filename.startswith("."): continue elif ( - ((filename.endswith(".image") or filename.endswith(".bin")) and type == "dynamips") - or ((filename.endswith(".bin") or filename.startswith("i86bi")) and type == "iou") - or (not filename.endswith(".bin") and not filename.endswith(".image") and type == "qemu") + ((filename.endswith(".image") or filename.endswith(".bin")) and image_type == "dynamips") + or ((filename.endswith(".bin") or filename.startswith("i86bi")) and image_type == "iou") + or (not filename.endswith(".bin") and not filename.endswith(".image") and image_type == "qemu") ): files.add(filename) @@ -71,7 +77,7 @@ def list_images(type): path = os.path.relpath(os.path.join(root, filename), default_directory) try: - if type in ["dynamips", "iou"]: + if image_type in ["dynamips", "iou"]: with open(os.path.join(root, filename), "rb") as f: # read the first 7 bytes of the file. elf_header_start = f.read(7) @@ -110,20 +116,21 @@ def _os_walk(directory, recurse=True, **kwargs): yield directory, [], files -def default_images_directory(type): +def default_images_directory(image_type): """ - :returns: Return the default directory for a node type + :returns: Return the default directory for an image type. """ + server_config = Config.instance().settings.Server img_dir = os.path.expanduser(server_config.images_path) - if type == "qemu": + if image_type == "qemu": return os.path.join(img_dir, "QEMU") - elif type == "iou": + elif image_type == "iou": return os.path.join(img_dir, "IOU") - elif type == "dynamips": + elif image_type == "dynamips" or image_type == "ios": return os.path.join(img_dir, "IOS") else: - raise NotImplementedError("%s node type is not supported", type) + raise NotImplementedError(f"%s node type is not supported", image_type) def images_directories(type): @@ -206,3 +213,71 @@ def remove_checksum(path): path = f"{path}.md5sum" if os.path.exists(path): os.remove(path) + + +class InvalidImageError(Exception): + + def __init__(self, message: str): + super().__init__() + self._message = message + + def __str__(self): + return self._message + + +def check_valid_image_header(data: bytes, image_type: str, header_magic_len: int) -> None: + + if image_type == "ios": + # file must start with the ELF magic number, be 32-bit, big endian and have an ELF version of 1 + if data[:header_magic_len] != b'\x7fELF\x01\x02\x01': + raise InvalidImageError("Invalid IOS file detected") + elif image_type == "iou": + # file must start with the ELF magic number, be 32-bit or 64-bit, little endian and have an ELF version of 1 + # (normal IOS images are big endian!) + if data[:header_magic_len] != b'\x7fELF\x01\x01\x01' and data[:7] != b'\x7fELF\x02\x01\x01': + raise InvalidImageError("Invalid IOU file detected") + # elif image_type == "qemu": + # if data[:expected_header_magic_len] != b'QFI\xfb': + # raise InvalidImageError("Invalid Qemu file detected (must be raw or qcow2)") + + +async def write_image( + image_name: str, + image_type: str, + path: str, + stream: AsyncGenerator[bytes, None], + images_repo: ImagesRepository, + check_image_header=True +) -> models.Image: + + log.info(f"Writing image file to '{path}'") + # Store the file under its final name only when the upload is completed + tmp_path = path + ".tmp" + os.makedirs(os.path.dirname(path), exist_ok=True) + checksum = hashlib.md5() + header_magic_len = 7 + if image_type == "qemu": + header_magic_len = 4 + try: + async with aiofiles.open(tmp_path, "wb") as f: + async for chunk in stream: + if check_image_header and len(chunk) >= header_magic_len: + check_image_header = False + check_valid_image_header(chunk, image_type, header_magic_len) + await f.write(chunk) + checksum.update(chunk) + + file_size = os.path.getsize(tmp_path) + if not file_size or file_size < header_magic_len: + raise InvalidImageError("The image content is empty or too small to be valid") + + checksum = checksum.hexdigest() + duplicate_image = await images_repo.get_image_by_checksum(checksum) + if duplicate_image: + raise InvalidImageError(f"Image {duplicate_image.filename} with same checksum already exists") + except InvalidImageError: + os.remove(tmp_path) + raise + os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC) + shutil.move(tmp_path, path) + return await images_repo.add_image(image_name, image_type, path, checksum, checksum_algorithm="md5") diff --git a/tests/api/routes/controller/test_images.py b/tests/api/routes/controller/test_images.py new file mode 100644 index 00000000..dd81b732 --- /dev/null +++ b/tests/api/routes/controller/test_images.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import pytest +import hashlib + +from fastapi import FastAPI, status +from httpx import AsyncClient + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def iou_32_bit_image(tmpdir) -> str: + """ + Create a fake IOU image on disk + """ + + path = os.path.join(tmpdir, "iou_32bit.bin") + with open(path, "wb+") as f: + f.write(b'\x7fELF\x01\x01\x01') + return path + + +@pytest.fixture +def iou_64_bit_image(tmpdir) -> str: + """ + Create a fake IOU image on disk + """ + + path = os.path.join(tmpdir, "iou_64bit.bin") + with open(path, "wb+") as f: + f.write(b'\x7fELF\x02\x01\x01') + return path + + +@pytest.fixture +def ios_image(tmpdir) -> str: + """ + Create a fake IOS image on disk + """ + + path = os.path.join(tmpdir, "ios.bin") + with open(path, "wb+") as f: + f.write(b'\x7fELF\x01\x02\x01') + return path + + +@pytest.fixture +def qcow2_image(tmpdir) -> str: + """ + Create a fake Qemu qcow2 image on disk + """ + + path = os.path.join(tmpdir, "image.qcow2") + with open(path, "wb+") as f: + f.write(b'QFI\xfb') + return path + + +@pytest.fixture +def invalid_image(tmpdir) -> str: + """ + Create a fake invalid image on disk + """ + + path = os.path.join(tmpdir, "invalid_image.bin") + with open(path, "wb+") as f: + f.write(b'\x01\x01\x01\x01') + return path + + +@pytest.fixture +def empty_image(tmpdir) -> str: + """ + Create a fake empty image on disk + """ + + path = os.path.join(tmpdir, "empty_image.bin") + with open(path, "wb+") as f: + f.write(b'') + return path + + +class TestImageRoutes: + + @pytest.mark.parametrize( + "image_type, fixture_name, valid_request", + ( + ("iou", "iou_32_bit_image", True), + ("iou", "iou_64_bit_image", True), + ("iou", "invalid_image", False), + ("ios", "ios_image", True), + ("ios", "invalid_image", False), + ("qemu", "qcow2_image", True), + ("qemu", "empty_image", False), + ("wrong_type", "qcow2_image", False), + ), + ) + async def test_upload_image( + self, + app: FastAPI, + client: AsyncClient, + images_dir: str, + image_type: str, + fixture_name: str, + valid_request: bool, + request + ) -> None: + + image_path = request.getfixturevalue(fixture_name) + image_name = os.path.basename(image_path) + image_checksum = hashlib.md5() + with open(image_path, "rb") as f: + image_data = f.read() + image_checksum.update(image_data) + + response = await client.post( + app.url_path_for("upload_image", image_name=image_name), + params={"image_type": image_type}, + content=image_data) + + if valid_request: + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["filename"] == image_name + assert response.json()["checksum"] == image_checksum.hexdigest() + assert os.path.exists(os.path.join(images_dir, image_type.upper(), image_name)) + else: + assert response.status_code != status.HTTP_201_CREATED + + async def test_image_list(self, app: FastAPI, client: AsyncClient) -> None: + + response = await client.get(app.url_path_for("get_images")) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 4 # 4 valid images uploaded before + + async def test_image_get(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: + + image_name = os.path.basename(qcow2_image) + response = await client.get(app.url_path_for("get_image", image_name=image_name)) + assert response.status_code == status.HTTP_200_OK + assert response.json()["filename"] == image_name + + async def test_same_image_cannot_be_uploaded(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: + + image_name = os.path.basename(qcow2_image) + with open(qcow2_image, "rb") as f: + image_data = f.read() + response = await client.post( + app.url_path_for("upload_image", image_name=image_name), + params={"image_type": "qemu"}, + content=image_data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + async def test_image_delete(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None: + + image_name = os.path.basename(qcow2_image) + response = await client.delete(app.url_path_for("delete_image", image_name=image_name)) + assert response.status_code == status.HTTP_204_NO_CONTENT + + async def test_not_found_image(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: + + image_name = os.path.basename(qcow2_image) + response = await client.get(app.url_path_for("get_image", image_name=image_name)) + assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_image_deleted_on_disk(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None: + + image_name = os.path.basename(qcow2_image) + with open(qcow2_image, "rb") as f: + image_data = f.read() + response = await client.post( + app.url_path_for("upload_image", image_name=image_name), + params={"image_type": "qemu"}, + content=image_data) + assert response.status_code == status.HTTP_201_CREATED + + response = await client.delete(app.url_path_for("delete_image", image_name=image_name)) + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not os.path.exists(os.path.join(images_dir, "QEMU", image_name))