diff --git a/gns3server/api/routes/controller/__init__.py b/gns3server/api/routes/controller/__init__.py
index 2e12010e..a0f51f7e 100644
--- a/gns3server/api/routes/controller/__init__.py
+++ b/gns3server/api/routes/controller/__init__.py
@@ -28,6 +28,7 @@ from . import projects
from . import snapshots
from . import symbols
from . import templates
+from . import images
from . import users
from . import groups
from . import roles
@@ -61,9 +62,17 @@ router.include_router(
tags=["Permissions"]
)
+router.include_router(
+ images.router,
+ dependencies=[Depends(get_current_active_user)],
+ prefix="/images",
+ tags=["Images"]
+)
+
router.include_router(
templates.router,
dependencies=[Depends(get_current_active_user)],
+ prefix="/templates",
tags=["Templates"]
)
diff --git a/gns3server/api/routes/controller/groups.py b/gns3server/api/routes/controller/groups.py
index a028985b..4e9cb8a8 100644
--- a/gns3server/api/routes/controller/groups.py
+++ b/gns3server/api/routes/controller/groups.py
@@ -25,6 +25,7 @@ from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
+ ControllerError,
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
@@ -126,7 +127,7 @@ async def delete_user_group(
success = await users_repo.delete_user_group(user_group_id)
if not success:
- raise ControllerNotFoundError(f"User group '{user_group_id}' could not be deleted")
+ raise ControllerError(f"User group '{user_group_id}' could not be deleted")
@router.get("/{user_group_id}/members", response_model=List[schemas.User])
diff --git a/gns3server/api/routes/controller/images.py b/gns3server/api/routes/controller/images.py
new file mode 100644
index 00000000..739d7e19
--- /dev/null
+++ b/gns3server/api/routes/controller/images.py
@@ -0,0 +1,122 @@
+#
+# Copyright (C) 2021 GNS3 Technologies Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""
+API routes for images.
+"""
+
+import os
+import logging
+import urllib.parse
+
+from fastapi import APIRouter, Request, Depends, status
+from typing import List
+from gns3server import schemas
+
+from gns3server.utils.images import InvalidImageError, default_images_directory, write_image
+from gns3server.db.repositories.images import ImagesRepository
+from gns3server.controller.controller_error import (
+ ControllerError,
+ ControllerNotFoundError,
+ ControllerForbiddenError,
+ ControllerBadRequestError
+)
+
+from .dependencies.database import get_repository
+
+log = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+@router.get("")
+async def get_images(
+ images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)),
+) -> List[schemas.Image]:
+ """
+ Return all images.
+ """
+
+ return await images_repo.get_images()
+
+
+@router.post("/upload/{image_name}", response_model=schemas.Image, status_code=status.HTTP_201_CREATED)
+async def upload_image(
+ image_name: str,
+ image_type: schemas.ImageType,
+ request: Request,
+ images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)),
+) -> schemas.Image:
+ """
+ Upload an image.
+ """
+
+ image_name = urllib.parse.unquote(image_name)
+ directory = default_images_directory(image_type)
+ path = os.path.abspath(os.path.join(directory, image_name))
+ if os.path.commonprefix([directory, path]) != directory:
+ raise ControllerForbiddenError(f"Could not write image: {image_name}, '{path}' is forbidden")
+
+ if await images_repo.get_image(image_name):
+ raise ControllerBadRequestError(f"Image '{image_name}' already exists")
+
+ try:
+ image = await write_image(image_name, image_type, path, request.stream(), images_repo)
+ except (OSError, InvalidImageError) as e:
+ raise ControllerError(f"Could not save {image_type} image '{image_name}': {e}")
+
+ return image
+
+
+@router.get("/{image_name}", response_model=schemas.Image)
+async def get_image(
+ image_name: str,
+ images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)),
+) -> schemas.Image:
+ """
+ Return an image.
+ """
+
+ image = await images_repo.get_image(image_name)
+ if not image:
+ raise ControllerNotFoundError(f"Image '{image_name}' not found")
+ return image
+
+
+@router.delete("/{image_name}", status_code=status.HTTP_204_NO_CONTENT)
+async def delete_image(
+ image_name: str,
+ images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)),
+) -> None:
+ """
+ Delete an image.
+ """
+
+ image = await images_repo.get_image(image_name)
+ if not image:
+ raise ControllerNotFoundError(f"Image '{image_name}' not found")
+
+ if await images_repo.get_image_templates(image.id):
+ raise ControllerError(f"Image '{image_name}' is used by one or more templates")
+
+ try:
+ os.remove(image.path)
+ except OSError:
+ log.warning(f"Could not delete image file {image.path}")
+
+ success = await images_repo.delete_image(image_name)
+ if not success:
+ raise ControllerError(f"Image '{image_name}' could not be deleted")
diff --git a/gns3server/api/routes/controller/permissions.py b/gns3server/api/routes/controller/permissions.py
index 466a2707..3c033cef 100644
--- a/gns3server/api/routes/controller/permissions.py
+++ b/gns3server/api/routes/controller/permissions.py
@@ -25,6 +25,7 @@ from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
+ ControllerError,
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
@@ -114,4 +115,4 @@ async def delete_permission(
success = await rbac_repo.delete_permission(permission_id)
if not success:
- raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted")
+ raise ControllerError(f"Permission '{permission_id}' could not be deleted")
diff --git a/gns3server/api/routes/controller/roles.py b/gns3server/api/routes/controller/roles.py
index fb8be351..b28c25e4 100644
--- a/gns3server/api/routes/controller/roles.py
+++ b/gns3server/api/routes/controller/roles.py
@@ -25,6 +25,7 @@ from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
+ ControllerError,
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
@@ -119,7 +120,7 @@ async def delete_role(
success = await rbac_repo.delete_role(role_id)
if not success:
- raise ControllerNotFoundError(f"Role '{role_id}' could not be deleted")
+ raise ControllerError(f"Role '{role_id}' could not be deleted")
@router.get("/{role_id}/permissions", response_model=List[schemas.Permission])
diff --git a/gns3server/api/routes/controller/templates.py b/gns3server/api/routes/controller/templates.py
index 9f1cf07d..dac698b1 100644
--- a/gns3server/api/routes/controller/templates.py
+++ b/gns3server/api/routes/controller/templates.py
@@ -42,7 +42,7 @@ responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find
router = APIRouter(responses=responses)
-@router.post("/templates", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
+@router.post("", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
async def create_template(
template_create: schemas.TemplateCreate,
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
@@ -59,7 +59,7 @@ async def create_template(
return template
-@router.get("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True)
+@router.get("/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True)
async def get_template(
template_id: UUID,
request: Request,
@@ -81,7 +81,7 @@ async def get_template(
return template
-@router.put("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True)
+@router.put("/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True)
async def update_template(
template_id: UUID,
template_update: schemas.TemplateUpdate,
@@ -94,10 +94,7 @@ async def update_template(
return await TemplatesService(templates_repo).update_template(template_id, template_update)
-@router.delete(
- "/templates/{template_id}",
- status_code=status.HTTP_204_NO_CONTENT,
-)
+@router.delete("/{template_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_template(
template_id: UUID,
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
@@ -111,7 +108,7 @@ async def delete_template(
await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}")
-@router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True)
+@router.get("", response_model=List[schemas.Template], response_model_exclude_unset=True)
async def get_templates(
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
current_user: schemas.User = Depends(get_current_active_user),
@@ -138,7 +135,7 @@ async def get_templates(
return user_templates
-@router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
+@router.post("/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
async def duplicate_template(
template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
current_user: schemas.User = Depends(get_current_active_user),
diff --git a/gns3server/api/routes/controller/users.py b/gns3server/api/routes/controller/users.py
index edf53b9a..421cda0c 100644
--- a/gns3server/api/routes/controller/users.py
+++ b/gns3server/api/routes/controller/users.py
@@ -26,6 +26,7 @@ from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
+ ControllerError,
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
@@ -194,7 +195,7 @@ async def delete_user(
success = await users_repo.delete_user(user_id)
if not success:
- raise ControllerNotFoundError(f"User '{user_id}' could not be deleted")
+ raise ControllerError(f"User '{user_id}' could not be deleted")
@router.get(
diff --git a/gns3server/db/models/__init__.py b/gns3server/db/models/__init__.py
index ed5f7ead..d10d0668 100644
--- a/gns3server/db/models/__init__.py
+++ b/gns3server/db/models/__init__.py
@@ -20,6 +20,7 @@ from .users import User, UserGroup
from .roles import Role
from .permissions import Permission
from .computes import Compute
+from .images import Image
from .templates import (
Template,
CloudTemplate,
diff --git a/gns3server/db/models/images.py b/gns3server/db/models/images.py
new file mode 100644
index 00000000..3565d4c5
--- /dev/null
+++ b/gns3server/db/models/images.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+#
+# Copyright (C) 2021 GNS3 Technologies Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from sqlalchemy import Column, String, Integer
+from sqlalchemy.orm import relationship
+
+from .base import BaseTable
+
+
+class Image(BaseTable):
+
+ __tablename__ = "images"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ filename = Column(String, unique=True, index=True)
+ image_type = Column(String)
+ path = Column(String)
+ checksum = Column(String)
+ checksum_algorithm = Column(String)
+ templates = relationship("Template")
diff --git a/gns3server/db/models/templates.py b/gns3server/db/models/templates.py
index 470b88cb..75795039 100644
--- a/gns3server/db/models/templates.py
+++ b/gns3server/db/models/templates.py
@@ -35,6 +35,8 @@ class Template(BaseTable):
usage = Column(String)
template_type = Column(String)
+ image_id = Column(Integer, ForeignKey('images.id', ondelete="CASCADE"))
+
__mapper_args__ = {
"polymorphic_identity": "templates",
"polymorphic_on": template_type,
diff --git a/gns3server/db/repositories/computes.py b/gns3server/db/repositories/computes.py
index b76842cb..2ea00bbd 100644
--- a/gns3server/db/repositories/computes.py
+++ b/gns3server/db/repositories/computes.py
@@ -23,15 +23,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from .base import BaseRepository
import gns3server.db.models as models
-from gns3server.services import auth_service
from gns3server import schemas
class ComputesRepository(BaseRepository):
+
def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session)
- self._auth_service = auth_service
async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]:
diff --git a/gns3server/db/repositories/images.py b/gns3server/db/repositories/images.py
new file mode 100644
index 00000000..d9ef8d91
--- /dev/null
+++ b/gns3server/db/repositories/images.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+#
+# Copyright (C) 2021 GNS3 Technologies Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from typing import Optional, List
+from sqlalchemy import select, delete
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from .base import BaseRepository
+
+import gns3server.db.models as models
+
+
+class ImagesRepository(BaseRepository):
+
+ def __init__(self, db_session: AsyncSession) -> None:
+
+ super().__init__(db_session)
+
+ async def get_image(self, image_name: str) -> Optional[models.Image]:
+
+ query = select(models.Image).where(models.Image.filename == image_name)
+ result = await self._db_session.execute(query)
+ return result.scalars().first()
+
+ async def get_images(self) -> List[models.Image]:
+
+ query = select(models.Image)
+ result = await self._db_session.execute(query)
+ return result.scalars().all()
+
+ async def get_image_templates(self, image_id: int) -> Optional[List[models.Template]]:
+
+ query = select(models.Template).\
+ join(models.Image.templates). \
+ filter(models.Image.id == image_id)
+
+ result = await self._db_session.execute(query)
+ return result.scalars().all()
+
+ async def get_image_by_checksum(self, checksum: str) -> Optional[models.Image]:
+
+ query = select(models.Image).where(models.Image.checksum == checksum)
+ result = await self._db_session.execute(query)
+ return result.scalars().first()
+
+ async def add_image(self, image_name, image_type, path, checksum, checksum_algorithm) -> models.Image:
+
+ db_image = models.Image(
+ id=None,
+ filename=image_name,
+ image_type=image_type,
+ path=path,
+ checksum=checksum,
+ checksum_algorithm=checksum_algorithm
+ )
+
+ self._db_session.add(db_image)
+ await self._db_session.commit()
+ await self._db_session.refresh(db_image)
+ return db_image
+
+ async def delete_image(self, image_name: str) -> bool:
+
+ query = delete(models.Image).where(models.Image.filename == image_name)
+ result = await self._db_session.execute(query)
+ await self._db_session.commit()
+ return result.rowcount > 0
diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py
index e5683059..23b89f72 100644
--- a/gns3server/schemas/__init__.py
+++ b/gns3server/schemas/__init__.py
@@ -23,6 +23,7 @@ from .version import Version
from .controller.links import LinkCreate, LinkUpdate, Link
from .controller.computes import ComputeCreate, ComputeUpdate, AutoIdlePC, Compute
from .controller.templates import TemplateCreate, TemplateUpdate, TemplateUsage, Template
+from .controller.images import Image, ImageType
from .controller.drawings import Drawing
from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
diff --git a/gns3server/schemas/controller/images.py b/gns3server/schemas/controller/images.py
new file mode 100644
index 00000000..992e6742
--- /dev/null
+++ b/gns3server/schemas/controller/images.py
@@ -0,0 +1,44 @@
+#
+# Copyright (C) 2021 GNS3 Technologies Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from pydantic import BaseModel, Field
+from enum import Enum
+
+from .base import DateTimeModelMixin
+
+
+class ImageType(str, Enum):
+
+ qemu = "qemu"
+ ios = "ios"
+ iou = "iou"
+
+
+class ImageBase(BaseModel):
+ """
+ Common image properties.
+ """
+
+ filename: str = Field(..., description="Image name")
+ image_type: ImageType = Field(..., description="Image type")
+ checksum: str = Field(..., description="Checksum value")
+ checksum_algorithm: str = Field(..., description="Checksum algorithm")
+
+
+class Image(DateTimeModelMixin, ImageBase):
+
+ class Config:
+ orm_mode = True
diff --git a/gns3server/utils/images.py b/gns3server/utils/images.py
index 5a1fdb2d..3843d6f6 100644
--- a/gns3server/utils/images.py
+++ b/gns3server/utils/images.py
@@ -16,21 +16,27 @@
import os
import hashlib
+import stat
+import aiofiles
+import shutil
+from typing import AsyncGenerator
from ..config import Config
from . import force_unix_path
+import gns3server.db.models as models
+from gns3server.db.repositories.images import ImagesRepository
import logging
log = logging.getLogger(__name__)
-def list_images(type):
+def list_images(image_type):
"""
- Scan directories for available image for a type
+ Scan directories for available image for a given type.
- :param type: emulator type (dynamips, qemu, iou)
+ :param image_type: image type (dynamips, qemu, iou)
"""
files = set()
images = []
@@ -39,9 +45,9 @@ def list_images(type):
general_images_directory = os.path.expanduser(server_config.images_path)
# Subfolder of the general_images_directory specific to this VM type
- default_directory = default_images_directory(type)
+ default_directory = default_images_directory(image_type)
- for directory in images_directories(type):
+ for directory in images_directories(image_type):
# We limit recursion to path outside the default images directory
# the reason is in the default directory manage file organization and
@@ -58,9 +64,9 @@ def list_images(type):
if filename.endswith(".md5sum") or filename.startswith("."):
continue
elif (
- ((filename.endswith(".image") or filename.endswith(".bin")) and type == "dynamips")
- or ((filename.endswith(".bin") or filename.startswith("i86bi")) and type == "iou")
- or (not filename.endswith(".bin") and not filename.endswith(".image") and type == "qemu")
+ ((filename.endswith(".image") or filename.endswith(".bin")) and image_type == "dynamips")
+ or ((filename.endswith(".bin") or filename.startswith("i86bi")) and image_type == "iou")
+ or (not filename.endswith(".bin") and not filename.endswith(".image") and image_type == "qemu")
):
files.add(filename)
@@ -71,7 +77,7 @@ def list_images(type):
path = os.path.relpath(os.path.join(root, filename), default_directory)
try:
- if type in ["dynamips", "iou"]:
+ if image_type in ["dynamips", "iou"]:
with open(os.path.join(root, filename), "rb") as f:
# read the first 7 bytes of the file.
elf_header_start = f.read(7)
@@ -110,20 +116,21 @@ def _os_walk(directory, recurse=True, **kwargs):
yield directory, [], files
-def default_images_directory(type):
+def default_images_directory(image_type):
"""
- :returns: Return the default directory for a node type
+ :returns: Return the default directory for an image type.
"""
+
server_config = Config.instance().settings.Server
img_dir = os.path.expanduser(server_config.images_path)
- if type == "qemu":
+ if image_type == "qemu":
return os.path.join(img_dir, "QEMU")
- elif type == "iou":
+ elif image_type == "iou":
return os.path.join(img_dir, "IOU")
- elif type == "dynamips":
+ elif image_type == "dynamips" or image_type == "ios":
return os.path.join(img_dir, "IOS")
else:
- raise NotImplementedError("%s node type is not supported", type)
+ raise NotImplementedError(f"%s node type is not supported", image_type)
def images_directories(type):
@@ -206,3 +213,71 @@ def remove_checksum(path):
path = f"{path}.md5sum"
if os.path.exists(path):
os.remove(path)
+
+
+class InvalidImageError(Exception):
+
+ def __init__(self, message: str):
+ super().__init__()
+ self._message = message
+
+ def __str__(self):
+ return self._message
+
+
+def check_valid_image_header(data: bytes, image_type: str, header_magic_len: int) -> None:
+
+ if image_type == "ios":
+ # file must start with the ELF magic number, be 32-bit, big endian and have an ELF version of 1
+ if data[:header_magic_len] != b'\x7fELF\x01\x02\x01':
+ raise InvalidImageError("Invalid IOS file detected")
+ elif image_type == "iou":
+ # file must start with the ELF magic number, be 32-bit or 64-bit, little endian and have an ELF version of 1
+ # (normal IOS images are big endian!)
+ if data[:header_magic_len] != b'\x7fELF\x01\x01\x01' and data[:7] != b'\x7fELF\x02\x01\x01':
+ raise InvalidImageError("Invalid IOU file detected")
+ # elif image_type == "qemu":
+ # if data[:expected_header_magic_len] != b'QFI\xfb':
+ # raise InvalidImageError("Invalid Qemu file detected (must be raw or qcow2)")
+
+
+async def write_image(
+ image_name: str,
+ image_type: str,
+ path: str,
+ stream: AsyncGenerator[bytes, None],
+ images_repo: ImagesRepository,
+ check_image_header=True
+) -> models.Image:
+
+ log.info(f"Writing image file to '{path}'")
+ # Store the file under its final name only when the upload is completed
+ tmp_path = path + ".tmp"
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ checksum = hashlib.md5()
+ header_magic_len = 7
+ if image_type == "qemu":
+ header_magic_len = 4
+ try:
+ async with aiofiles.open(tmp_path, "wb") as f:
+ async for chunk in stream:
+ if check_image_header and len(chunk) >= header_magic_len:
+ check_image_header = False
+ check_valid_image_header(chunk, image_type, header_magic_len)
+ await f.write(chunk)
+ checksum.update(chunk)
+
+ file_size = os.path.getsize(tmp_path)
+ if not file_size or file_size < header_magic_len:
+ raise InvalidImageError("The image content is empty or too small to be valid")
+
+ checksum = checksum.hexdigest()
+ duplicate_image = await images_repo.get_image_by_checksum(checksum)
+ if duplicate_image:
+ raise InvalidImageError(f"Image {duplicate_image.filename} with same checksum already exists")
+ except InvalidImageError:
+ os.remove(tmp_path)
+ raise
+ os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC)
+ shutil.move(tmp_path, path)
+ return await images_repo.add_image(image_name, image_type, path, checksum, checksum_algorithm="md5")
diff --git a/tests/api/routes/controller/test_images.py b/tests/api/routes/controller/test_images.py
new file mode 100644
index 00000000..dd81b732
--- /dev/null
+++ b/tests/api/routes/controller/test_images.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python
+#
+# Copyright (C) 2021 GNS3 Technologies Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import os
+import pytest
+import hashlib
+
+from fastapi import FastAPI, status
+from httpx import AsyncClient
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+def iou_32_bit_image(tmpdir) -> str:
+ """
+ Create a fake IOU image on disk
+ """
+
+ path = os.path.join(tmpdir, "iou_32bit.bin")
+ with open(path, "wb+") as f:
+ f.write(b'\x7fELF\x01\x01\x01')
+ return path
+
+
+@pytest.fixture
+def iou_64_bit_image(tmpdir) -> str:
+ """
+ Create a fake IOU image on disk
+ """
+
+ path = os.path.join(tmpdir, "iou_64bit.bin")
+ with open(path, "wb+") as f:
+ f.write(b'\x7fELF\x02\x01\x01')
+ return path
+
+
+@pytest.fixture
+def ios_image(tmpdir) -> str:
+ """
+ Create a fake IOS image on disk
+ """
+
+ path = os.path.join(tmpdir, "ios.bin")
+ with open(path, "wb+") as f:
+ f.write(b'\x7fELF\x01\x02\x01')
+ return path
+
+
+@pytest.fixture
+def qcow2_image(tmpdir) -> str:
+ """
+ Create a fake Qemu qcow2 image on disk
+ """
+
+ path = os.path.join(tmpdir, "image.qcow2")
+ with open(path, "wb+") as f:
+ f.write(b'QFI\xfb')
+ return path
+
+
+@pytest.fixture
+def invalid_image(tmpdir) -> str:
+ """
+ Create a fake invalid image on disk
+ """
+
+ path = os.path.join(tmpdir, "invalid_image.bin")
+ with open(path, "wb+") as f:
+ f.write(b'\x01\x01\x01\x01')
+ return path
+
+
+@pytest.fixture
+def empty_image(tmpdir) -> str:
+ """
+ Create a fake empty image on disk
+ """
+
+ path = os.path.join(tmpdir, "empty_image.bin")
+ with open(path, "wb+") as f:
+ f.write(b'')
+ return path
+
+
+class TestImageRoutes:
+
+ @pytest.mark.parametrize(
+ "image_type, fixture_name, valid_request",
+ (
+ ("iou", "iou_32_bit_image", True),
+ ("iou", "iou_64_bit_image", True),
+ ("iou", "invalid_image", False),
+ ("ios", "ios_image", True),
+ ("ios", "invalid_image", False),
+ ("qemu", "qcow2_image", True),
+ ("qemu", "empty_image", False),
+ ("wrong_type", "qcow2_image", False),
+ ),
+ )
+ async def test_upload_image(
+ self,
+ app: FastAPI,
+ client: AsyncClient,
+ images_dir: str,
+ image_type: str,
+ fixture_name: str,
+ valid_request: bool,
+ request
+ ) -> None:
+
+ image_path = request.getfixturevalue(fixture_name)
+ image_name = os.path.basename(image_path)
+ image_checksum = hashlib.md5()
+ with open(image_path, "rb") as f:
+ image_data = f.read()
+ image_checksum.update(image_data)
+
+ response = await client.post(
+ app.url_path_for("upload_image", image_name=image_name),
+ params={"image_type": image_type},
+ content=image_data)
+
+ if valid_request:
+ assert response.status_code == status.HTTP_201_CREATED
+ assert response.json()["filename"] == image_name
+ assert response.json()["checksum"] == image_checksum.hexdigest()
+ assert os.path.exists(os.path.join(images_dir, image_type.upper(), image_name))
+ else:
+ assert response.status_code != status.HTTP_201_CREATED
+
+ async def test_image_list(self, app: FastAPI, client: AsyncClient) -> None:
+
+ response = await client.get(app.url_path_for("get_images"))
+ assert response.status_code == status.HTTP_200_OK
+ assert len(response.json()) == 4 # 4 valid images uploaded before
+
+ async def test_image_get(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None:
+
+ image_name = os.path.basename(qcow2_image)
+ response = await client.get(app.url_path_for("get_image", image_name=image_name))
+ assert response.status_code == status.HTTP_200_OK
+ assert response.json()["filename"] == image_name
+
+ async def test_same_image_cannot_be_uploaded(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None:
+
+ image_name = os.path.basename(qcow2_image)
+ with open(qcow2_image, "rb") as f:
+ image_data = f.read()
+ response = await client.post(
+ app.url_path_for("upload_image", image_name=image_name),
+ params={"image_type": "qemu"},
+ content=image_data)
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+
+ async def test_image_delete(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None:
+
+ image_name = os.path.basename(qcow2_image)
+ response = await client.delete(app.url_path_for("delete_image", image_name=image_name))
+ assert response.status_code == status.HTTP_204_NO_CONTENT
+
+ async def test_not_found_image(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None:
+
+ image_name = os.path.basename(qcow2_image)
+ response = await client.get(app.url_path_for("get_image", image_name=image_name))
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ async def test_image_deleted_on_disk(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None:
+
+ image_name = os.path.basename(qcow2_image)
+ with open(qcow2_image, "rb") as f:
+ image_data = f.read()
+ response = await client.post(
+ app.url_path_for("upload_image", image_name=image_name),
+ params={"image_type": "qemu"},
+ content=image_data)
+ assert response.status_code == status.HTTP_201_CREATED
+
+ response = await client.delete(app.url_path_for("delete_image", image_name=image_name))
+ assert response.status_code == status.HTTP_204_NO_CONTENT
+ assert not os.path.exists(os.path.join(images_dir, "QEMU", image_name))