1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-02-11 07:32:40 +00:00

Save computes to database

This commit is contained in:
grossmj 2021-04-05 14:21:41 +09:30
parent e607793e74
commit 566e326b57
13 changed files with 515 additions and 337 deletions

View File

@ -19,20 +19,23 @@
API routes for computes. API routes for computes.
""" """
from fastapi import APIRouter, status from fastapi import APIRouter, Depends, status
from fastapi.encoders import jsonable_encoder
from typing import List, Union from typing import List, Union
from uuid import UUID from uuid import UUID
from gns3server.controller import Controller from gns3server.controller import Controller
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.services.computes import ComputesService
from gns3server import schemas from gns3server import schemas
router = APIRouter() from .dependencies.database import get_repository
responses = { responses = {
404: {"model": schemas.ErrorMessage, "description": "Compute not found"} 404: {"model": schemas.ErrorMessage, "description": "Compute not found"}
} }
router = APIRouter(responses=responses)
@router.post("", @router.post("",
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
@ -40,69 +43,73 @@ responses = {
responses={404: {"model": schemas.ErrorMessage, "description": "Could not connect to compute"}, responses={404: {"model": schemas.ErrorMessage, "description": "Could not connect to compute"},
409: {"model": schemas.ErrorMessage, "description": "Could not create compute"}, 409: {"model": schemas.ErrorMessage, "description": "Could not create compute"},
401: {"model": schemas.ErrorMessage, "description": "Invalid authentication for compute"}}) 401: {"model": schemas.ErrorMessage, "description": "Invalid authentication for compute"}})
async def create_compute(compute_data: schemas.ComputeCreate): async def create_compute(
compute_create: schemas.ComputeCreate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
""" """
Create a new compute on the controller. Create a new compute on the controller.
""" """
compute = await Controller.instance().add_compute(**jsonable_encoder(compute_data, exclude_unset=True), return await ComputesService(computes_repo).create_compute(compute_create)
connect=False)
return compute.__json__()
@router.get("/{compute_id}", @router.get("/{compute_id}",
response_model=schemas.Compute, response_model=schemas.Compute,
response_model_exclude_unset=True, response_model_exclude_unset=True)
responses=responses) async def get_compute(
def get_compute(compute_id: Union[str, UUID]): compute_id: Union[str, UUID],
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
""" """
Return a compute from the controller. Return a compute from the controller.
""" """
compute = Controller.instance().get_compute(str(compute_id)) return await ComputesService(computes_repo).get_compute(compute_id)
return compute.__json__()
@router.get("", @router.get("",
response_model=List[schemas.Compute], response_model=List[schemas.Compute],
response_model_exclude_unset=True) response_model_exclude_unset=True)
async def get_computes(): async def get_computes(
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> List[schemas.Compute]:
""" """
Return all computes known by the controller. Return all computes known by the controller.
""" """
controller = Controller.instance() return await ComputesService(computes_repo).get_computes()
return [c.__json__() for c in controller.computes.values()]
@router.put("/{compute_id}", @router.put("/{compute_id}",
response_model=schemas.Compute, response_model=schemas.Compute,
response_model_exclude_unset=True, response_model_exclude_unset=True)
responses=responses) async def update_compute(
async def update_compute(compute_id: Union[str, UUID], compute_data: schemas.ComputeUpdate): compute_id: Union[str, UUID],
compute_update: schemas.ComputeUpdate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
""" """
Update a compute on the controller. Update a compute on the controller.
""" """
compute = Controller.instance().get_compute(str(compute_id)) return await ComputesService(computes_repo).update_compute(compute_id, compute_update)
# exclude compute_id because we only use it when creating a new compute
await compute.update(**jsonable_encoder(compute_data, exclude_unset=True, exclude={"compute_id"}))
return compute.__json__()
@router.delete("/{compute_id}", @router.delete("/{compute_id}",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT)
responses=responses) async def delete_compute(
async def delete_compute(compute_id: Union[str, UUID]): compute_id: Union[str, UUID],
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
):
""" """
Delete a compute from the controller. Delete a compute from the controller.
""" """
await Controller.instance().delete_compute(str(compute_id)) await ComputesService(computes_repo).delete_compute(compute_id)
@router.get("/{compute_id}/{emulator}/images", @router.get("/{compute_id}/{emulator}/images")
responses=responses)
async def get_images(compute_id: Union[str, UUID], emulator: str): async def get_images(compute_id: Union[str, UUID], emulator: str):
""" """
Return the list of images available on a compute for a given emulator type. Return the list of images available on a compute for a given emulator type.
@ -113,8 +120,7 @@ async def get_images(compute_id: Union[str, UUID], emulator: str):
return await compute.images(emulator) return await compute.images(emulator)
@router.get("/{compute_id}/{emulator}/{endpoint_path:path}", @router.get("/{compute_id}/{emulator}/{endpoint_path:path}")
responses=responses)
async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path: str): async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path: str):
""" """
Forward a GET request to a compute. Forward a GET request to a compute.
@ -126,8 +132,7 @@ async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path
return result return result
@router.post("/{compute_id}/{emulator}/{endpoint_path:path}", @router.post("/{compute_id}/{emulator}/{endpoint_path:path}")
responses=responses)
async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict): async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
""" """
Forward a POST request to a compute. Forward a POST request to a compute.
@ -138,8 +143,7 @@ async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_pat
return await compute.forward("POST", emulator, endpoint_path, data=compute_data) return await compute.forward("POST", emulator, endpoint_path, data=compute_data)
@router.put("/{compute_id}/{emulator}/{endpoint_path:path}", @router.put("/{compute_id}/{emulator}/{endpoint_path:path}")
responses=responses)
async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict): async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
""" """
Forward a PUT request to a compute. Forward a PUT request to a compute.
@ -150,8 +154,7 @@ async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path
return await compute.forward("PUT", emulator, endpoint_path, data=compute_data) return await compute.forward("PUT", emulator, endpoint_path, data=compute_data)
@router.post("/{compute_id}/auto_idlepc", @router.post("/{compute_id}/auto_idlepc")
responses=responses)
async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdlePC): async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdlePC):
""" """
Find a suitable Idle-PC value for a given IOS image. This may take a few minutes. Find a suitable Idle-PC value for a given IOS image. This may take a few minutes.
@ -162,14 +165,3 @@ async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdl
auto_idle_pc.platform, auto_idle_pc.platform,
auto_idle_pc.image, auto_idle_pc.image,
auto_idle_pc.ram) auto_idle_pc.ram)
@router.get("/{compute_id}/ports",
deprecated=True,
responses=responses)
async def ports(compute_id: Union[str, UUID]):
"""
Return ports information for a given compute.
"""
return await Controller.instance().compute_ports(str(compute_id))

View File

@ -44,43 +44,42 @@ router = APIRouter()
@router.get("", response_model=List[schemas.User]) @router.get("", response_model=List[schemas.User])
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]: async def get_users(users_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
""" """
Get all users. Get all users.
""" """
users = await user_repo.get_users() return await users_repo.get_users()
return users
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED) @router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
async def create_user( async def create_user(
new_user: schemas.UserCreate, user_create: schemas.UserCreate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User: ) -> schemas.User:
""" """
Create a new user. Create a new user.
""" """
if await user_repo.get_user_by_username(new_user.username): if await users_repo.get_user_by_username(user_create.username):
raise ControllerBadRequestError(f"Username '{new_user.username}' is already registered") raise ControllerBadRequestError(f"Username '{user_create.username}' is already registered")
if new_user.email and await user_repo.get_user_by_email(new_user.email): if user_create.email and await users_repo.get_user_by_email(user_create.email):
raise ControllerBadRequestError(f"Email '{new_user.email}' is already registered") raise ControllerBadRequestError(f"Email '{user_create.email}' is already registered")
return await user_repo.create_user(new_user) return await users_repo.create_user(user_create)
@router.get("/{user_id}", response_model=schemas.User) @router.get("/{user_id}", response_model=schemas.User)
async def get_user( async def get_user(
user_id: UUID, user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User: ) -> schemas.User:
""" """
Get an user. Get an user.
""" """
user = await user_repo.get_user(user_id) user = await users_repo.get_user(user_id)
if not user: if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found") raise ControllerNotFoundError(f"User '{user_id}' not found")
return user return user
@ -89,14 +88,14 @@ async def get_user(
@router.put("/{user_id}", response_model=schemas.User) @router.put("/{user_id}", response_model=schemas.User)
async def update_user( async def update_user(
user_id: UUID, user_id: UUID,
update_user: schemas.UserUpdate, user_update: schemas.UserUpdate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User: ) -> schemas.User:
""" """
Update an user. Update an user.
""" """
user = await user_repo.update_user(user_id, update_user) user = await users_repo.update_user(user_id, user_update)
if not user: if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found") raise ControllerNotFoundError(f"User '{user_id}' not found")
return user return user
@ -105,7 +104,7 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user( async def delete_user(
user_id: UUID, user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user) current_user: schemas.User = Depends(get_current_active_user)
) -> None: ) -> None:
""" """
@ -115,21 +114,21 @@ async def delete_user(
if current_user.is_superuser: if current_user.is_superuser:
raise ControllerUnauthorizedError("The super user cannot be deleted") raise ControllerUnauthorizedError("The super user cannot be deleted")
success = await user_repo.delete_user(user_id) success = await users_repo.delete_user(user_id)
if not success: if not success:
raise ControllerNotFoundError(f"User '{user_id}' not found") raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.post("/login", response_model=schemas.Token) @router.post("/login", response_model=schemas.Token)
async def login( async def login(
user_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends() form_data: OAuth2PasswordRequestForm = Depends()
) -> schemas.Token: ) -> schemas.Token:
""" """
User login. User login.
""" """
user = await user_repo.authenticate_user(username=form_data.username, password=form_data.password) user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user: if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.", detail="Authentication was unsuccessful.",

View File

@ -37,6 +37,7 @@ from ..utils.get_resource import get_resource
from .gns3vm.gns3_vm_error import GNS3VMError from .gns3vm.gns3_vm_error import GNS3VMError
from .controller_error import ControllerError, ControllerNotFoundError from .controller_error import ControllerError, ControllerNotFoundError
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -47,6 +48,7 @@ class Controller:
""" """
def __init__(self): def __init__(self):
self._computes = {} self._computes = {}
self._projects = {} self._projects = {}
self._notification = Notification(self) self._notification = Notification(self)
@ -59,7 +61,7 @@ class Controller:
self._config_file = Config.instance().controller_config self._config_file = Config.instance().controller_config
log.info("Load controller configuration file {}".format(self._config_file)) log.info("Load controller configuration file {}".format(self._config_file))
async def start(self): async def start(self, computes):
log.info("Controller is starting") log.info("Controller is starting")
self.load_base_files() self.load_base_files()
@ -78,7 +80,7 @@ class Controller:
if name == "gns3vm": if name == "gns3vm":
name = "Main server" name = "Main server"
computes = self._load_controller_settings() self._load_controller_settings()
ssl_context = None ssl_context = None
if server_config.getboolean("ssl"): if server_config.getboolean("ssl"):
@ -198,22 +200,20 @@ class Controller:
if self._config_loaded is False: if self._config_loaded is False:
return return
controller_settings = {"computes": [], controller_settings = {"gns3vm": self.gns3vm.__json__(),
"templates": [],
"gns3vm": self.gns3vm.__json__(),
"iou_license": self._iou_license_settings, "iou_license": self._iou_license_settings,
"appliances_etag": self._appliance_manager.appliances_etag, "appliances_etag": self._appliance_manager.appliances_etag,
"version": __version__} "version": __version__}
for compute in self._computes.values(): # for compute in self._computes.values():
if compute.id != "local" and compute.id != "vm": # if compute.id != "local" and compute.id != "vm":
controller_settings["computes"].append({"host": compute.host, # controller_settings["computes"].append({"host": compute.host,
"name": compute.name, # "name": compute.name,
"port": compute.port, # "port": compute.port,
"protocol": compute.protocol, # "protocol": compute.protocol,
"user": compute.user, # "user": compute.user,
"password": compute.password, # "password": compute.password,
"compute_id": compute.id}) # "compute_id": compute.id})
try: try:
os.makedirs(os.path.dirname(self._config_file), exist_ok=True) os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
@ -584,14 +584,3 @@ class Controller:
await project.delete() await project.delete()
self.remove_project(project) self.remove_project(project)
return res return res
async def compute_ports(self, compute_id):
"""
Get the ports used by a compute.
:param compute_id: ID of the compute
"""
compute = self.get_compute(compute_id)
response = await compute.get("/network/ports")
return response.json

View File

@ -70,10 +70,10 @@ class Compute:
assert controller is not None assert controller is not None
log.info("Create compute %s", compute_id) log.info("Create compute %s", compute_id)
if compute_id is None: # if compute_id is None:
self._id = str(uuid.uuid4()) # self._id = str(uuid.uuid4())
else: # else:
self._id = compute_id self._id = compute_id
self.protocol = protocol self.protocol = protocol
self._console_host = console_host self._console_host = console_host
@ -181,17 +181,8 @@ class Compute:
@name.setter @name.setter
def name(self, name): def name(self, name):
if name is not None:
self._name = name self._name = name
else:
if self._user:
user = self._user
# Due to random user generated by 1.4 it's common to have a very long user
if len(user) > 14:
user = user[:11] + "..."
self._name = "{}://{}@{}:{}".format(self._protocol, user, self._host, self._port)
else:
self._name = "{}://{}:{}".format(self._protocol, self._host, self._port)
@property @property
def connected(self): def connected(self):

View File

@ -25,8 +25,7 @@ from gns3server.controller import Controller
from gns3server.compute import MODULES from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager from gns3server.compute.port_manager import PortManager
from gns3server.utils.http_client import HTTPClient from gns3server.utils.http_client import HTTPClient
from gns3server.db.tasks import connect_to_db from gns3server.db.tasks import connect_to_db, get_computes
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -57,11 +56,14 @@ def create_startup_handler(app: FastAPI) -> Callable:
# connect to the database # connect to the database
await connect_to_db(app) await connect_to_db(app)
await Controller.instance().start() # retrieve the computes from the database
computes = await get_computes(app)
await Controller.instance().start(computes)
# Because with a large image collection # Because with a large image collection
# without md5sum already computed we start the # without md5sum already computed we start the
# computing with server start # computing with server start
from gns3server.compute.qemu import Qemu from gns3server.compute.qemu import Qemu
asyncio.ensure_future(Qemu.instance().list_images()) asyncio.ensure_future(Qemu.instance().list_images())

View File

@ -17,6 +17,7 @@
from .base import Base from .base import Base
from .users import User from .users import User
from .computes import Compute
from .templates import ( from .templates import (
Template, Template,
CloudTemplate, CloudTemplate,

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Column, String
from .base import BaseTable, GUID
class Compute(BaseTable):
__tablename__ = "computes"
compute_id = Column(GUID, primary_key=True)
name = Column(String, index=True)
protocol = Column(String)
host = Column(String)
port = Column(String)
user = Column(String)
password = Column(String)

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID
from typing import Optional, List
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from .base import BaseRepository
import gns3server.db.models as models
from gns3server.services import auth_service
from gns3server import schemas
class ComputesRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session)
self._auth_service = auth_service
async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]:
query = select(models.Compute).where(models.Compute.compute_id == compute_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_compute_by_name(self, name: str) -> Optional[models.Compute]:
query = select(models.Compute).where(models.Compute.name == name)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_computes(self) -> List[models.Compute]:
query = select(models.Compute)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
db_compute = models.Compute(
compute_id=compute_create.compute_id,
name=compute_create.name,
protocol=compute_create.protocol.value,
host=compute_create.host,
port=compute_create.port,
user=compute_create.user,
password=compute_create.password
)
self._db_session.add(db_compute)
await self._db_session.commit()
await self._db_session.refresh(db_compute)
return db_compute
async def update_compute(self, compute_id: UUID, compute_update: schemas.ComputeUpdate) -> Optional[models.Compute]:
update_values = compute_update.dict(exclude_unset=True)
query = update(models.Compute) \
.where(models.Compute.compute_id == compute_id) \
.values(update_values)
await self._db_session.execute(query)
await self._db_session.commit()
return await self.get_compute(compute_id)
async def delete_compute(self, compute_id: UUID) -> bool:
query = delete(models.Compute).where(models.Compute.compute_id == compute_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0

View File

@ -18,8 +18,14 @@
import os import os
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from pydantic import ValidationError
from typing import List
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from gns3server.db.repositories.computes import ComputesRepository
from gns3server import schemas
from .models import Base from .models import Base
from gns3server.config import Config from gns3server.config import Config
@ -40,3 +46,21 @@ async def connect_to_db(app: FastAPI) -> None:
app.state._db_engine = engine app.state._db_engine = engine
except SQLAlchemyError as e: except SQLAlchemyError as e:
log.error(f"Error while connecting to database '{db_url}: {e}") log.error(f"Error while connecting to database '{db_url}: {e}")
async def get_computes(app: FastAPI) -> List[dict]:
computes = []
async with AsyncSession(app.state._db_engine) as db_session:
db_computes = await ComputesRepository(db_session).get_computes()
for db_compute in db_computes:
try:
compute = jsonable_encoder(
schemas.Compute.from_orm(db_compute),
exclude_unset=True,
exclude={"created_at", "updated_at"})
except ValidationError as e:
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
continue
computes.append(compute)
return computes

View File

@ -15,12 +15,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
from typing import List, Optional, Union from typing import List, Optional, Union
from uuid import UUID from uuid import UUID, uuid4
from enum import Enum from enum import Enum
from .nodes import NodeType from .nodes import NodeType
from .base import DateTimeModelMixin
class Protocol(str, Enum): class Protocol(str, Enum):
@ -37,12 +38,11 @@ class ComputeBase(BaseModel):
Data to create a compute. Data to create a compute.
""" """
compute_id: Optional[Union[str, UUID]] = None
name: Optional[str] = None
protocol: Protocol protocol: Protocol
host: str host: str
port: int = Field(..., gt=0, le=65535) port: int = Field(..., gt=0, le=65535)
user: Optional[str] = None user: Optional[str] = None
name: Optional[str] = None
class ComputeCreate(ComputeBase): class ComputeCreate(ComputeBase):
@ -50,6 +50,7 @@ class ComputeCreate(ComputeBase):
Data to create a compute. Data to create a compute.
""" """
compute_id: Union[str, UUID] = Field(default_factory=uuid4)
password: Optional[str] = None password: Optional[str] = None
class Config: class Config:
@ -63,6 +64,24 @@ class ComputeCreate(ComputeBase):
} }
} }
@validator("name", always=True)
def generate_name(cls, name, values):
if name is not None:
return name
else:
protocol = values.get("protocol")
host = values.get("host")
port = values.get("port")
user = values.get("user")
if user:
# due to random user generated by 1.4 it's common to have a very long user
if len(user) > 14:
user = user[:11] + "..."
return "{}://{}@{}:{}".format(protocol, user, host, port)
else:
return "{}://{}:{}".format(protocol, host, port)
class ComputeUpdate(ComputeBase): class ComputeUpdate(ComputeBase):
""" """
@ -96,7 +115,7 @@ class Capabilities(BaseModel):
disk_size: int = Field(..., description="Disk size on this compute") disk_size: int = Field(..., description="Disk size on this compute")
class Compute(ComputeBase): class Compute(DateTimeModelMixin, ComputeBase):
""" """
Data returned for a compute. Data returned for a compute.
""" """
@ -110,6 +129,9 @@ class Compute(ComputeBase):
last_error: Optional[str] = Field(None, description="Last error found on the compute") last_error: Optional[str] = Field(None, description="Last error found on the compute")
capabilities: Optional[Capabilities] = None capabilities: Optional[Capabilities] = None
class Config:
orm_mode = True
class AutoIdlePC(BaseModel): class AutoIdlePC(BaseModel):
""" """

View File

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID
from typing import List, Union
from gns3server import schemas
import gns3server.db.models as models
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.controller import Controller
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError
)
class ComputesService:
def __init__(self, computes_repo: ComputesRepository):
self._computes_repo = computes_repo
self._controller = Controller.instance()
async def get_computes(self) -> List[models.Compute]:
db_computes = await self._computes_repo.get_computes()
return db_computes
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
if await self._computes_repo.get_compute(compute_create.compute_id):
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
db_compute = await self._computes_repo.create_compute(compute_create)
await self._controller.add_compute(compute_id=str(db_compute.compute_id),
connect=False,
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}))
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
return db_compute
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
db_compute = await self._computes_repo.get_compute(compute_id)
if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
return db_compute
async def update_compute(
self,
compute_id: Union[str, UUID],
compute_update: schemas.ComputeUpdate
) -> models.Compute:
compute = self._controller.get_compute(str(compute_id))
await compute.update(**compute_update.dict(exclude_unset=True))
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
return db_compute
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
if await self._computes_repo.delete_compute(compute_id):
await self._controller.delete_compute(str(compute_id))
self._controller.notification.controller_emit("compute.deleted", {"compute_id": str(compute_id)})
else:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")

View File

@ -15,12 +15,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import uuid
import pytest import pytest
from fastapi import FastAPI, status from fastapi import FastAPI, status
from httpx import AsyncClient from httpx import AsyncClient
from gns3server.controller import Controller from gns3server.schemas.computes import Compute
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -28,234 +29,167 @@ import unittest
from tests.utils import asyncio_patch from tests.utils import asyncio_patch
async def test_compute_create_without_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None: class TestComputeRoutes:
params = { async def test_compute_create(self, app: FastAPI, client: AsyncClient) -> None:
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
response = await client.post(app.url_path_for("create_compute"), json=params) params = {
assert response.status_code == status.HTTP_201_CREATED "protocol": "http",
response_content = response.json() "host": "localhost",
assert response_content["user"] == "julien" "port": 84,
assert response_content["compute_id"] is not None "user": "julien",
assert "password" not in response_content "password": "secure"}
assert len(controller.computes) == 1
assert controller.computes[response_content["compute_id"]].host == "localhost" response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["compute_id"] is not None
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_create_with_id(self, app: FastAPI, client: AsyncClient) -> None:
compute_id = str(uuid.uuid4())
params = {
"compute_id": compute_id,
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["compute_id"] == compute_id
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_list(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_computes"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) > 0
async def test_compute_get(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
response = await client.get(app.url_path_for("get_compute", compute_id=test_compute.compute_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["compute_id"] == str(test_compute.compute_id)
async def test_compute_update(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
params["protocol"] = "https"
response = await client.put(app.url_path_for("update_compute", compute_id=compute_id), json=params)
assert response.status_code == status.HTTP_200_OK
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_delete(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
response = await client.delete(app.url_path_for("delete_compute", compute_id=test_compute.compute_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_compute_create_with_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None: class TestComputeFeatures:
params = { async def test_compute_list_images(self, app: FastAPI, client: AsyncClient) -> None:
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
response = await client.post(app.url_path_for("create_compute"), json=params) params = {
assert response.status_code == status.HTTP_201_CREATED "protocol": "http",
assert response.json()["user"] == "julien" "host": "localhost",
assert "password" not in response.json() "port": 84,
assert len(controller.computes) == 1 "user": "julien",
assert controller.computes["my_compute_id"].host == "localhost" "password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
async def test_compute_get(app: FastAPI, client: AsyncClient, controller: Controller) -> None: with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
response = await client.get(app.url_path_for("delete_compute", compute_id=compute_id) + "/qemu/images")
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
mock.assert_called_with("qemu")
params = { async def test_compute_list_vms(self, app: FastAPI, client: AsyncClient) -> None:
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params) params = {
assert response.status_code == status.HTTP_201_CREATED "protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
response = await client.get(app.url_path_for("update_compute", compute_id="my_compute_id")) with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
assert response.status_code == status.HTTP_200_OK response = await client.get(app.url_path_for("get_compute", compute_id=compute_id) + "/virtualbox/vms")
mock.assert_called_with("GET", "virtualbox", "vms")
assert response.json() == []
async def test_compute_create_img(self, app: FastAPI, client: AsyncClient) -> None:
async def test_compute_update(app: FastAPI, client: AsyncClient) -> None: params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
params = { response = await client.post(app.url_path_for("get_computes"), json=params)
"compute_id": "my_compute_id", assert response.status_code == status.HTTP_201_CREATED
"protocol": "http", compute_id = response.json()["compute_id"]
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params) params = {"path": "/test"}
assert response.status_code == status.HTTP_201_CREATED with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id")) response = await client.post(app.url_path_for("get_compute", compute_id=compute_id) + "/qemu/img", json=params)
assert response.status_code == status.HTTP_200_OK assert response.json() == []
assert response.json()["protocol"] == "http" mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
params["protocol"] = "https" # async def test_compute_autoidlepc(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.put(app.url_path_for("update_compute", compute_id="my_compute_id"), json=params) #
# params = {
assert response.status_code == status.HTTP_200_OK # "protocol": "http",
assert response.json()["protocol"] == "https" # "host": "localhost",
# "port": 84,
# "user": "julien",
async def test_compute_list(app: FastAPI, client: AsyncClient, controller: Controller) -> None: # "password": "secure"
# }
params = { #
"compute_id": "my_compute_id", # response = await client.post(app.url_path_for("get_computes"), json=params)
"protocol": "http", # assert response.status_code == status.HTTP_201_CREATED
"host": "localhost", # compute_id = response.json()["compute_id"]
"port": 84, #
"user": "julien", # params = {
"password": "secure", # "platform": "c7200",
"name": "My super server" # "image": "test.bin",
} # "ram": 512
# }
response = await client.post(app.url_path_for("create_compute"), json=params) #
assert response.status_code == status.HTTP_201_CREATED # with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
assert response.json()["user"] == "julien" # response = await client.post(app.url_path_for("autoidlepc", compute_id=compute_id) + "/auto_idlepc", json=params)
assert "password" not in response.json() # assert mock.called
# assert response.status_code == status.HTTP_200_OK
response = await client.get(app.url_path_for("get_computes"))
for compute in response.json():
if compute['compute_id'] != 'local':
assert compute == {
'compute_id': 'my_compute_id',
'connected': False,
'host': 'localhost',
'port': 84,
'protocol': 'http',
'user': 'julien',
'name': 'My super server',
'cpu_usage_percent': 0.0,
'memory_usage_percent': 0.0,
'disk_usage_percent': 0.0,
'last_error': None,
'capabilities': {
'version': '',
'platform': '',
'cpus': 0,
'memory': 0,
'disk_size': 0,
'node_types': []
}
}
async def test_compute_delete(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
response = await client.get(app.url_path_for("get_computes"))
assert len(response.json()) == 1
response = await client.delete(app.url_path_for("delete_compute", compute_id="my_compute_id"))
assert response.status_code == status.HTTP_204_NO_CONTENT
response = await client.get(app.url_path_for("get_computes"))
assert len(response.json()) == 0
async def test_compute_list_images(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
response = await client.get(app.url_path_for("delete_compute", compute_id="my_compute_id") + "/qemu/images")
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
mock.assert_called_with("qemu")
async def test_compute_list_vms(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id") + "/virtualbox/vms")
mock.assert_called_with("GET", "virtualbox", "vms")
assert response.json() == []
async def test_compute_create_img(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
params = {"path": "/test"}
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/qemu/img", json=params)
assert response.json() == []
mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
async def test_compute_autoidlepc(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
await client.post(app.url_path_for("get_computes"), json=params)
params = {
"platform": "c7200",
"image": "test.bin",
"ram": 512
}
with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/auto_idlepc", json=params)
assert mock.called
assert response.status_code == status.HTTP_200_OK
# FIXME # FIXME

View File

@ -4,6 +4,7 @@ import tempfile
import shutil import shutil
import sys import sys
import os import os
import uuid
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@ -16,10 +17,12 @@ from gns3server.config import Config
from gns3server.compute import MODULES from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager from gns3server.compute.port_manager import PortManager
from gns3server.compute.project_manager import ProjectManager from gns3server.compute.project_manager import ProjectManager
from gns3server.db.models import Base, User from gns3server.db.models import Base, User, Compute
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.api.routes.controller.dependencies.database import get_db_session from gns3server.api.routes.controller.dependencies.database import get_db_session
from gns3server.schemas.users import UserCreate from gns3server import schemas
from gns3server.schemas.computes import Protocol
from gns3server.services import auth_service from gns3server.services import auth_service
sys._called_from_test = True sys._called_from_test = True
@ -27,7 +30,7 @@ sys.original_platform = sys.platform
if sys.platform.startswith("win") and sys.version_info < (3, 8): if sys.platform.startswith("win") and sys.version_info < (3, 8):
@pytest.yield_fixture(scope="session") @pytest.fixture(scope="session")
def event_loop(request): def event_loop(request):
""" """
Overwrite pytest_asyncio event loop on Windows for Python < 3.8 Overwrite pytest_asyncio event loop on Windows for Python < 3.8
@ -43,7 +46,7 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
# https://github.com/pytest-dev/pytest-asyncio/issues/68 # https://github.com/pytest-dev/pytest-asyncio/issues/68
# this event_loop is used by pytest-asyncio, and redefining it # this event_loop is used by pytest-asyncio, and redefining it
# is currently the only way of changing the scope of this fixture # is currently the only way of changing the scope of this fixture
@pytest.yield_fixture(scope="class") @pytest.fixture(scope="class")
def event_loop(request): def event_loop(request):
loop = asyncio.get_event_loop_policy().new_event_loop() loop = asyncio.get_event_loop_policy().new_event_loop()
@ -54,9 +57,6 @@ def event_loop(request):
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
async def app() -> FastAPI: async def app() -> FastAPI:
# async with db_engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all)
# await conn.run_sync(Base.metadata.create_all)
from gns3server.api.server import app as gns3app from gns3server.api.server import app as gns3app
yield gns3app yield gns3app
@ -109,7 +109,7 @@ async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
@pytest.fixture @pytest.fixture
async def test_user(db_session: AsyncSession) -> User: async def test_user(db_session: AsyncSession) -> User:
new_user = UserCreate( new_user = schemas.UserCreate(
username="user1", username="user1",
email="user1@email.com", email="user1@email.com",
password="user1_password", password="user1_password",
@ -121,6 +121,25 @@ async def test_user(db_session: AsyncSession) -> User:
return await user_repo.create_user(new_user) return await user_repo.create_user(new_user)
@pytest.fixture
async def test_compute(db_session: AsyncSession) -> Compute:
new_compute = schemas.ComputeCreate(
compute_id=uuid.uuid4(),
protocol=Protocol.http,
host="localhost",
port=4242,
user="julien",
password="secure"
)
compute_repo = ComputesRepository(db_session)
existing_compute = await compute_repo.get_compute(new_compute.compute_id)
if existing_compute:
return existing_compute
return await compute_repo.create_compute(new_compute)
@pytest.fixture @pytest.fixture
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient: def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient: