From 0fea3f969ec52dc2bd07e9353b12373929af65d2 Mon Sep 17 00:00:00 2001 From: grossmj Date: Sun, 28 Mar 2021 21:17:29 +1030 Subject: [PATCH] Use aiosqlite and add service for templates --- .../controller/dependencies/database.py | 2 +- gns3server/api/routes/controller/templates.py | 77 ++----- gns3server/compute/compute_error.py | 2 +- gns3server/controller/controller_error.py | 2 +- gns3server/db/models/base.py | 6 +- gns3server/db/models/templates.py | 32 ++- gns3server/db/repositories/base.py | 2 - gns3server/db/repositories/templates.py | 171 ++------------ gns3server/db/tasks.py | 2 +- gns3server/schemas/templates.py | 4 +- gns3server/services/templates.py | 217 ++++++++++++++++++ requirements.txt | 3 +- tests/api/routes/controller/test_templates.py | 2 +- tests/conftest.py | 2 +- 14 files changed, 292 insertions(+), 232 deletions(-) create mode 100644 gns3server/services/templates.py diff --git a/gns3server/api/routes/controller/dependencies/database.py b/gns3server/api/routes/controller/dependencies/database.py index b1dbaa12..f3f59d88 100644 --- a/gns3server/api/routes/controller/dependencies/database.py +++ b/gns3server/api/routes/controller/dependencies/database.py @@ -24,7 +24,7 @@ from gns3server.db.repositories.base import BaseRepository async def get_db_session(request: Request) -> AsyncSession: - session = AsyncSession(request.app.state._db_engine) + session = AsyncSession(request.app.state._db_engine, expire_on_commit=False) try: yield session finally: diff --git a/gns3server/api/routes/controller/templates.py b/gns3server/api/routes/controller/templates.py index 46123189..a73ca1d3 100644 --- a/gns3server/api/routes/controller/templates.py +++ b/gns3server/api/routes/controller/templates.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright (C) 2020 GNS3 Technologies Inc. +# Copyright (C) 2021 GNS3 Technologies Inc. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -21,7 +21,6 @@ API routes for templates. import hashlib import json -import pydantic import logging log = logging.getLogger(__name__) @@ -33,54 +32,44 @@ from uuid import UUID from gns3server import schemas from gns3server.controller import Controller from gns3server.db.repositories.templates import TemplatesRepository -from gns3server.controller.controller_error import ( - ControllerBadRequestError, - ControllerNotFoundError, - ControllerForbiddenError -) +from gns3server.services.templates import TemplatesService from .dependencies.database import get_repository -router = APIRouter() - responses = { 404: {"model": schemas.ErrorMessage, "description": "Could not find template"} } +router = APIRouter(responses=responses) + @router.post("/templates", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) async def create_template( - new_template: schemas.TemplateCreate, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + template_data: schemas.TemplateCreate, + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> dict: """ Create a new template. """ - try: - return await template_repo.create_template(new_template) - except pydantic.ValidationError as e: - raise ControllerBadRequestError(f"JSON schema error received while creating new template: {e}") + return await TemplatesService(templates_repo).create_template(template_data) @router.get("/templates/{template_id}", response_model=schemas.Template, - response_model_exclude_unset=True, - responses=responses) + response_model_exclude_unset=True) async def get_template( template_id: UUID, request: Request, response: Response, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> dict: """ Return a template. """ request_etag = request.headers.get("If-None-Match", "") - template = await template_repo.get_template(template_id) - if not template: - raise ControllerNotFoundError(f"Template '{template_id}' not found") + template = await TemplatesService(templates_repo).get_template(template_id) data = json.dumps(template) template_etag = '"' + hashlib.md5(data.encode()).hexdigest() + '"' if template_etag == request_etag: @@ -92,75 +81,57 @@ async def get_template( @router.put("/templates/{template_id}", response_model=schemas.Template, - response_model_exclude_unset=True, - responses=responses) + response_model_exclude_unset=True) async def update_template( template_id: UUID, template_data: schemas.TemplateUpdate, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> dict: """ Update a template. """ - if template_repo.get_builtin_template(template_id): - raise ControllerForbiddenError(f"Template '{template_id}' cannot be updated because it is built-in") - template = await template_repo.update_template(template_id, template_data) - if not template: - raise ControllerNotFoundError(f"Template '{template_id}' not found") - return template + return await TemplatesService(templates_repo).update_template(template_id, template_data) @router.delete("/templates/{template_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses=responses) + status_code=status.HTTP_204_NO_CONTENT,) async def delete_template( template_id: UUID, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> None: """ Delete a template. """ - if template_repo.get_builtin_template(template_id): - raise ControllerForbiddenError(f"Template '{template_id}' cannot be deleted because it is built-in") - success = await template_repo.delete_template(template_id) - if not success: - raise ControllerNotFoundError(f"Template '{template_id}' not found") + await TemplatesService(templates_repo).delete_template(template_id) @router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True) async def get_templates( - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> List[dict]: """ Return all templates. """ - templates = await template_repo.get_templates() - return templates + return await TemplatesService(templates_repo).get_templates() @router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, - status_code=status.HTTP_201_CREATED, - responses=responses) + status_code=status.HTTP_201_CREATED) async def duplicate_template( template_id: UUID, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> dict: """ Duplicate a template. """ - if template_repo.get_builtin_template(template_id): - raise ControllerForbiddenError(f"Template '{template_id}' cannot be duplicated because it is built-in") - template = await template_repo.duplicate_template(template_id) - if not template: - raise ControllerNotFoundError(f"Template '{template_id}' not found") - return template + return await TemplatesService(templates_repo).duplicate_template(template_id) @router.post("/projects/{project_id}/templates/{template_id}", @@ -171,15 +142,13 @@ async def create_node_from_template( project_id: UUID, template_id: UUID, template_usage: schemas.TemplateUsage, - template_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) ) -> schemas.Node: """ Create a new node from a template. """ - template = await template_repo.get_template(template_id) - if not template: - raise ControllerNotFoundError(f"Template '{template_id}' not found") + template = TemplatesService(templates_repo).get_template(template_id) controller = Controller.instance() project = controller.get_project(str(project_id)) node = await project.add_node_from_template(template, diff --git a/gns3server/compute/compute_error.py b/gns3server/compute/compute_error.py index 08c6ea88..93f74c73 100644 --- a/gns3server/compute/compute_error.py +++ b/gns3server/compute/compute_error.py @@ -19,7 +19,7 @@ class ComputeError(Exception): def __init__(self, message: str): - super().__init__(message) + super().__init__() self._message = message def __repr__(self): diff --git a/gns3server/controller/controller_error.py b/gns3server/controller/controller_error.py index 2d54b15e..04569e7c 100644 --- a/gns3server/controller/controller_error.py +++ b/gns3server/controller/controller_error.py @@ -19,7 +19,7 @@ class ControllerError(Exception): def __init__(self, message: str): - super().__init__(message) + super().__init__() self._message = message def __repr__(self): diff --git a/gns3server/db/models/base.py b/gns3server/db/models/base.py index 09ad9c46..6459e122 100644 --- a/gns3server/db/models/base.py +++ b/gns3server/db/models/base.py @@ -27,14 +27,14 @@ from sqlalchemy.ext.declarative import as_declarative @as_declarative() class Base: - def _asdict(self): + def asdict(self): return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs} - def _asjson(self): + def asjson(self): - return jsonable_encoder(self._asdict()) + return jsonable_encoder(self.asdict()) class GUID(TypeDecorator): diff --git a/gns3server/db/models/templates.py b/gns3server/db/models/templates.py index ee2a83cd..2e2580f0 100644 --- a/gns3server/db/models/templates.py +++ b/gns3server/db/models/templates.py @@ -37,7 +37,7 @@ class Template(BaseTable): __mapper_args__ = { "polymorphic_identity": "templates", - "polymorphic_on": template_type + "polymorphic_on": template_type, } @@ -53,7 +53,8 @@ class CloudTemplate(Template): remote_console_http_path = Column(String) __mapper_args__ = { - "polymorphic_identity": "cloud" + "polymorphic_identity": "cloud", + "polymorphic_load": "selectin" } @@ -79,7 +80,8 @@ class DockerTemplate(Template): custom_adapters = Column(PickleType) __mapper_args__ = { - "polymorphic_identity": "docker" + "polymorphic_identity": "docker", + "polymorphic_load": "selectin" } @@ -124,7 +126,8 @@ class DynamipsTemplate(Template): wic2 = Column(String) __mapper_args__ = { - "polymorphic_identity": "dynamips" + "polymorphic_identity": "dynamips", + "polymorphic_load": "selectin" } @@ -136,7 +139,8 @@ class EthernetHubTemplate(Template): ports_mapping = Column(PickleType) __mapper_args__ = { - "polymorphic_identity": "ethernet_hub" + "polymorphic_identity": "ethernet_hub", + "polymorphic_load": "selectin" } @@ -149,7 +153,8 @@ class EthernetSwitchTemplate(Template): console_type = Column(String) __mapper_args__ = { - "polymorphic_identity": "ethernet_switch" + "polymorphic_identity": "ethernet_switch", + "polymorphic_load": "selectin" } @@ -171,7 +176,8 @@ class IOUTemplate(Template): console_auto_start = Column(Boolean) __mapper_args__ = { - "polymorphic_identity": "iou" + "polymorphic_identity": "iou", + "polymorphic_load": "selectin" } @@ -219,7 +225,8 @@ class QemuTemplate(Template): custom_adapters = Column(PickleType) __mapper_args__ = { - "polymorphic_identity": "qemu" + "polymorphic_identity": "qemu", + "polymorphic_load": "selectin" } @@ -244,7 +251,8 @@ class VirtualBoxTemplate(Template): custom_adapters = Column(PickleType) __mapper_args__ = { - "polymorphic_identity": "virtualbox" + "polymorphic_identity": "virtualbox", + "polymorphic_load": "selectin" } @@ -268,7 +276,8 @@ class VMwareTemplate(Template): custom_adapters = Column(PickleType) __mapper_args__ = { - "polymorphic_identity": "vmware" + "polymorphic_identity": "vmware", + "polymorphic_load": "selectin" } @@ -282,5 +291,6 @@ class VPCSTemplate(Template): console_auto_start = Column(Boolean, default=False) __mapper_args__ = { - "polymorphic_identity": "vpcs" + "polymorphic_identity": "vpcs", + "polymorphic_load": "selectin" } diff --git a/gns3server/db/repositories/base.py b/gns3server/db/repositories/base.py index e4e8179b..ab7c5ca3 100644 --- a/gns3server/db/repositories/base.py +++ b/gns3server/db/repositories/base.py @@ -16,7 +16,6 @@ # along with this program. If not, see . from sqlalchemy.ext.asyncio import AsyncSession -from gns3server.controller import Controller class BaseRepository: @@ -24,4 +23,3 @@ class BaseRepository: def __init__(self, db_session: AsyncSession) -> None: self._db_session = db_session - self._controller = Controller.instance() diff --git a/gns3server/db/repositories/templates.py b/gns3server/db/repositories/templates.py index 7122625a..8804d931 100644 --- a/gns3server/db/repositories/templates.py +++ b/gns3server/db/repositories/templates.py @@ -15,11 +15,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import uuid - from uuid import UUID -from typing import List -from fastapi.encoders import jsonable_encoder +from typing import List, Union from sqlalchemy import select, update, delete from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.session import make_transient @@ -29,29 +26,6 @@ from .base import BaseRepository import gns3server.db.models as models from gns3server import schemas -TEMPLATE_TYPE_TO_SHEMA = { - "cloud": schemas.CloudTemplate, - "ethernet_hub": schemas.EthernetHubTemplate, - "ethernet_switch": schemas.EthernetSwitchTemplate, - "docker": schemas.DockerTemplate, - "dynamips": schemas.DynamipsTemplate, - "vpcs": schemas.VPCSTemplate, - "virtualbox": schemas.VirtualBoxTemplate, - "vmware": schemas.VMwareTemplate, - "iou": schemas.IOUTemplate, - "qemu": schemas.QemuTemplate -} - -DYNAMIPS_PLATFORM_TO_SHEMA = { - "c7200": schemas.C7200DynamipsTemplate, - "c3745": schemas.C3745DynamipsTemplate, - "c3725": schemas.C3725DynamipsTemplate, - "c3600": schemas.C3600DynamipsTemplate, - "c2691": schemas.C2691DynamipsTemplate, - "c2600": schemas.C2600DynamipsTemplate, - "c1700": schemas.C1700DynamipsTemplate -} - TEMPLATE_TYPE_TO_MODEL = { "cloud": models.CloudTemplate, "docker": models.DockerTemplate, @@ -65,82 +39,6 @@ TEMPLATE_TYPE_TO_MODEL = { "vpcs": models.VPCSTemplate } -# built-in templates have their compute_id set to None to tell clients to select a compute -BUILTIN_TEMPLATES = [ - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "cloud"), - "template_type": "cloud", - "name": "Cloud", - "default_name_format": "Cloud{0}", - "category": "guest", - "symbol": ":/symbols/cloud.svg", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "nat"), - "template_type": "nat", - "name": "NAT", - "default_name_format": "NAT{0}", - "category": "guest", - "symbol": ":/symbols/cloud.svg", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "vpcs"), - "template_type": "vpcs", - "name": "VPCS", - "default_name_format": "PC{0}", - "category": "guest", - "symbol": ":/symbols/vpcs_guest.svg", - "base_script_file": "vpcs_base_config.txt", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "ethernet_switch"), - "template_type": "ethernet_switch", - "name": "Ethernet switch", - "console_type": "none", - "default_name_format": "Switch{0}", - "category": "switch", - "symbol": ":/symbols/ethernet_switch.svg", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "ethernet_hub"), - "template_type": "ethernet_hub", - "name": "Ethernet hub", - "default_name_format": "Hub{0}", - "category": "switch", - "symbol": ":/symbols/hub.svg", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "frame_relay_switch"), - "template_type": "frame_relay_switch", - "name": "Frame Relay switch", - "default_name_format": "FRSW{0}", - "category": "switch", - "symbol": ":/symbols/frame_relay_switch.svg", - "compute_id": None, - "builtin": True - }, - { - "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "atm_switch"), - "template_type": "atm_switch", - "name": "ATM switch", - "default_name_format": "ATMSW{0}", - "category": "switch", - "symbol": ":/symbols/atm_switch.svg", - "compute_id": None, - "builtin": True - }, -] - class TemplatesRepository(BaseRepository): @@ -148,58 +46,31 @@ class TemplatesRepository(BaseRepository): super().__init__(db_session) - def get_builtin_template(self, template_id: UUID) -> dict: - - for builtin_template in BUILTIN_TEMPLATES: - if builtin_template["template_id"] == template_id: - return jsonable_encoder(builtin_template) - - async def get_template(self, template_id: UUID) -> dict: + async def get_template(self, template_id: UUID) -> Union[None, models.Template]: query = select(models.Template).where(models.Template.template_id == template_id) - result = (await self._db_session.execute(query)).scalars().first() - if result: - return result._asjson() - else: - return self.get_builtin_template(template_id) + result = await self._db_session.execute(query) + return result.scalars().first() - async def get_templates(self) -> List[dict]: + async def get_templates(self) -> List[models.Template]: - templates = [] query = select(models.Template) result = await self._db_session.execute(query) - for db_template in result.scalars().all(): - templates.append(db_template._asjson()) - for builtin_template in BUILTIN_TEMPLATES: - templates.append(jsonable_encoder(builtin_template)) - return templates - - async def create_template(self, template_create: schemas.TemplateCreate) -> dict: - - # get the default template settings - template_settings = jsonable_encoder(template_create, exclude_unset=True) - template_schema = TEMPLATE_TYPE_TO_SHEMA[template_create.template_type] - template_settings_with_defaults = template_schema.parse_obj(template_settings) - settings = template_settings_with_defaults.dict() - if template_create.template_type == "dynamips": - # special case for Dynamips to cover all platform types that contain specific settings - dynamips_template_schema = DYNAMIPS_PLATFORM_TO_SHEMA[settings["platform"]] - dynamips_template_settings_with_defaults = dynamips_template_schema.parse_obj(template_settings) - settings = dynamips_template_settings_with_defaults.dict() - - model = TEMPLATE_TYPE_TO_MODEL[template_create.template_type] - db_template = model(**settings) + return result.scalars().all() + + async def create_template(self, template_type: str, template_settings: dict) -> models.Template: + + model = TEMPLATE_TYPE_TO_MODEL[template_type] + db_template = model(**template_settings) self._db_session.add(db_template) await self._db_session.commit() await self._db_session.refresh(db_template) - template = db_template._asjson() - self._controller.notification.controller_emit("template.created", template) - return template + return db_template async def update_template( self, template_id: UUID, - template_update: schemas.TemplateUpdate) -> dict: + template_update: schemas.TemplateUpdate) -> schemas.Template: update_values = template_update.dict(exclude_unset=True) @@ -209,22 +80,16 @@ class TemplatesRepository(BaseRepository): await self._db_session.execute(query) await self._db_session.commit() - template = await self.get_template(template_id) - if template: - self._controller.notification.controller_emit("template.updated", template) - return template + return await self.get_template(template_id) async def delete_template(self, template_id: UUID) -> bool: query = delete(models.Template).where(models.Template.template_id == template_id) result = await self._db_session.execute(query) await self._db_session.commit() - if result.rowcount > 0: - self._controller.notification.controller_emit("template.deleted", {"template_id": str(template_id)}) - return True - return False + return result.rowcount > 0 - async def duplicate_template(self, template_id: UUID) -> dict: + async def duplicate_template(self, template_id: UUID) -> schemas.Template: query = select(models.Template).where(models.Template.template_id == template_id) db_template = (await self._db_session.execute(query)).scalars().first() @@ -238,6 +103,4 @@ class TemplatesRepository(BaseRepository): self._db_session.add(db_template) await self._db_session.commit() await self._db_session.refresh(db_template) - template = db_template._asjson() - self._controller.notification.controller_emit("template.created", template) - return template + return db_template diff --git a/gns3server/db/tasks.py b/gns3server/db/tasks.py index 2a6d0787..7673ef77 100644 --- a/gns3server/db/tasks.py +++ b/gns3server/db/tasks.py @@ -31,7 +31,7 @@ log = logging.getLogger(__name__) async def connect_to_db(app: FastAPI) -> None: db_path = os.path.join(Config.instance().config_dir, "gns3_controller.db") - db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite+pysqlite:///{db_path}") + db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite+aiosqlite:///{db_path}") engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True) try: async with engine.begin() as conn: diff --git a/gns3server/schemas/templates.py b/gns3server/schemas/templates.py index fe270a25..cd66377c 100644 --- a/gns3server/schemas/templates.py +++ b/gns3server/schemas/templates.py @@ -52,7 +52,6 @@ class TemplateBase(BaseModel): class Config: extra = "allow" - orm_mode = True class TemplateCreate(TemplateBase): @@ -80,6 +79,9 @@ class Template(DateTimeModelMixin, TemplateBase): template_type: NodeType compute_id: Union[str, None] + class Config: + orm_mode = True + class TemplateUsage(BaseModel): diff --git a/gns3server/services/templates.py b/gns3server/services/templates.py new file mode 100644 index 00000000..a888ccd2 --- /dev/null +++ b/gns3server/services/templates.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import uuid +import pydantic + +from uuid import UUID +from fastapi.encoders import jsonable_encoder +from typing import List + +from gns3server import schemas +from gns3server.db.repositories.templates import TemplatesRepository +from gns3server.controller import Controller +from gns3server.controller.controller_error import ( + ControllerBadRequestError, + ControllerNotFoundError, + ControllerForbiddenError +) + +TEMPLATE_TYPE_TO_SHEMA = { + "cloud": schemas.CloudTemplate, + "ethernet_hub": schemas.EthernetHubTemplate, + "ethernet_switch": schemas.EthernetSwitchTemplate, + "docker": schemas.DockerTemplate, + "dynamips": schemas.DynamipsTemplate, + "vpcs": schemas.VPCSTemplate, + "virtualbox": schemas.VirtualBoxTemplate, + "vmware": schemas.VMwareTemplate, + "iou": schemas.IOUTemplate, + "qemu": schemas.QemuTemplate +} + +DYNAMIPS_PLATFORM_TO_SHEMA = { + "c7200": schemas.C7200DynamipsTemplate, + "c3745": schemas.C3745DynamipsTemplate, + "c3725": schemas.C3725DynamipsTemplate, + "c3600": schemas.C3600DynamipsTemplate, + "c2691": schemas.C2691DynamipsTemplate, + "c2600": schemas.C2600DynamipsTemplate, + "c1700": schemas.C1700DynamipsTemplate +} + +# built-in templates have their compute_id set to None to tell clients to select a compute +BUILTIN_TEMPLATES = [ + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "cloud"), + "template_type": "cloud", + "name": "Cloud", + "default_name_format": "Cloud{0}", + "category": "guest", + "symbol": ":/symbols/cloud.svg", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "nat"), + "template_type": "nat", + "name": "NAT", + "default_name_format": "NAT{0}", + "category": "guest", + "symbol": ":/symbols/cloud.svg", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "vpcs"), + "template_type": "vpcs", + "name": "VPCS", + "default_name_format": "PC{0}", + "category": "guest", + "symbol": ":/symbols/vpcs_guest.svg", + "base_script_file": "vpcs_base_config.txt", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "ethernet_switch"), + "template_type": "ethernet_switch", + "name": "Ethernet switch", + "console_type": "none", + "default_name_format": "Switch{0}", + "category": "switch", + "symbol": ":/symbols/ethernet_switch.svg", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "ethernet_hub"), + "template_type": "ethernet_hub", + "name": "Ethernet hub", + "default_name_format": "Hub{0}", + "category": "switch", + "symbol": ":/symbols/hub.svg", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "frame_relay_switch"), + "template_type": "frame_relay_switch", + "name": "Frame Relay switch", + "default_name_format": "FRSW{0}", + "category": "switch", + "symbol": ":/symbols/frame_relay_switch.svg", + "compute_id": None, + "builtin": True + }, + { + "template_id": uuid.uuid3(uuid.NAMESPACE_DNS, "atm_switch"), + "template_type": "atm_switch", + "name": "ATM switch", + "default_name_format": "ATMSW{0}", + "category": "switch", + "symbol": ":/symbols/atm_switch.svg", + "compute_id": None, + "builtin": True + }, +] + + +class TemplatesService: + + def __init__(self, templates_repo: TemplatesRepository): + + self._templates_repo = templates_repo + self._controller = Controller.instance() + + def get_builtin_template(self, template_id: UUID) -> dict: + + for builtin_template in BUILTIN_TEMPLATES: + if builtin_template["template_id"] == template_id: + return jsonable_encoder(builtin_template) + + async def get_templates(self) -> List[dict]: + + templates = [] + db_templates = await self._templates_repo.get_templates() + for db_template in db_templates: + templates.append(db_template.asjson()) + for builtin_template in BUILTIN_TEMPLATES: + templates.append(jsonable_encoder(builtin_template)) + return templates + + async def create_template(self, template_data: schemas.TemplateCreate) -> dict: + + try: + # get the default template settings + template_settings = jsonable_encoder(template_data, exclude_unset=True) + template_schema = TEMPLATE_TYPE_TO_SHEMA[template_data.template_type] + template_settings_with_defaults = template_schema.parse_obj(template_settings) + settings = template_settings_with_defaults.dict() + if template_data.template_type == "dynamips": + # special case for Dynamips to cover all platform types that contain specific settings + dynamips_template_schema = DYNAMIPS_PLATFORM_TO_SHEMA[settings["platform"]] + dynamips_template_settings_with_defaults = dynamips_template_schema.parse_obj(template_settings) + settings = dynamips_template_settings_with_defaults.dict() + except pydantic.ValidationError as e: + raise ControllerBadRequestError(f"JSON schema error received while creating new template: {e}") + db_template = await self._templates_repo.create_template(template_data.template_type, settings) + template = db_template.asjson() + self._controller.notification.controller_emit("template.created", template) + return template + + async def get_template(self, template_id: UUID) -> dict: + + db_template = await self._templates_repo.get_template(template_id) + if db_template: + template = db_template.asjson() + else: + template = self.get_builtin_template(template_id) + if not template: + raise ControllerNotFoundError(f"Template '{template_id}' not found") + return template + + async def update_template(self, template_id: UUID, template_data: schemas.TemplateUpdate) -> dict: + + if self.get_builtin_template(template_id): + raise ControllerForbiddenError(f"Template '{template_id}' cannot be updated because it is built-in") + template = await self._templates_repo.update_template(template_id, template_data) + if not template: + raise ControllerNotFoundError(f"Template '{template_id}' not found") + template = template.asjson() + self._controller.notification.controller_emit("template.updated", template) + return template + + async def duplicate_template(self, template_id: UUID) -> dict: + + if self.get_builtin_template(template_id): + raise ControllerForbiddenError(f"Template '{template_id}' cannot be duplicated because it is built-in") + db_template = await self._templates_repo.duplicate_template(template_id) + if not db_template: + raise ControllerNotFoundError(f"Template '{template_id}' not found") + template = db_template.asjson() + self._controller.notification.controller_emit("template.created", template) + return template + + async def delete_template(self, template_id: UUID) -> None: + + if self.get_builtin_template(template_id): + raise ControllerForbiddenError(f"Template '{template_id}' cannot be deleted because it is built-in") + if await self._templates_repo.delete_template(template_id): + self._controller.notification.controller_emit("template.deleted", {"template_id": str(template_id)}) + else: + raise ControllerNotFoundError(f"Template '{template_id}' not found") diff --git a/requirements.txt b/requirements.txt index 6e3c36c9..5c7aadca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,8 @@ psutil==5.7.3 async-timeout==3.0.1 distro==1.5.0 py-cpuinfo==7.0.0 -sqlalchemy==1.4.0b2 # beta version with asyncio support +sqlalchemy==1.4.3 +aiosqlite===0.17.0 passlib[bcrypt]==1.7.2 python-jose==3.2.0 email-validator==1.1.2 diff --git a/tests/api/routes/controller/test_templates.py b/tests/api/routes/controller/test_templates.py index 2c2ef2f3..67389073 100644 --- a/tests/api/routes/controller/test_templates.py +++ b/tests/api/routes/controller/test_templates.py @@ -23,7 +23,7 @@ from fastapi import FastAPI, status from httpx import AsyncClient from gns3server.controller import Controller -from gns3server.db.repositories.templates import BUILTIN_TEMPLATES +from gns3server.services.templates import BUILTIN_TEMPLATES pytestmark = pytest.mark.asyncio diff --git a/tests/conftest.py b/tests/conftest.py index 808a3cf1..bb938090 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,7 +64,7 @@ async def app() -> FastAPI: @pytest.fixture(scope="class") def db_engine(): - db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite:///:memory:") # "sqlite:///./sql_test_app.db" + db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite+aiosqlite:///:memory:") # "sqlite:///./sql_test_app.db" engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True) yield engine engine.sync_engine.dispose()