From 10fdd8fcf44bf54bc6af6f864f05b5473af61e87 Mon Sep 17 00:00:00 2001 From: grossmj Date: Fri, 24 Dec 2021 13:05:39 +1030 Subject: [PATCH] Add connect endpoint for computes Param to connect to compute after creation Report compute unauthorized HTTP errors to client --- gns3server/api/routes/controller/computes.py | 17 +++++++++++++-- gns3server/api/server.py | 8 +++---- gns3server/controller/__init__.py | 23 ++++++++++++-------- gns3server/controller/compute.py | 9 +++++--- gns3server/db/repositories/computes.py | 8 ++----- gns3server/db/repositories/users.py | 2 +- gns3server/db/tasks.py | 4 +--- gns3server/schemas/controller/computes.py | 8 +++++-- gns3server/services/computes.py | 10 ++++----- tests/controller/test_controller.py | 2 +- 10 files changed, 55 insertions(+), 36 deletions(-) diff --git a/gns3server/api/routes/controller/computes.py b/gns3server/api/routes/controller/computes.py index 34e10f5c..aa1f5e58 100644 --- a/gns3server/api/routes/controller/computes.py +++ b/gns3server/api/routes/controller/computes.py @@ -19,7 +19,7 @@ API routes for computes. """ from fastapi import APIRouter, Depends, Response, status -from typing import List, Union +from typing import List, Union, Optional from uuid import UUID from gns3server.controller import Controller @@ -47,12 +47,25 @@ router = APIRouter(responses=responses) async def create_compute( compute_create: schemas.ComputeCreate, computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)), + connect: Optional[bool] = False ) -> schemas.Compute: """ Create a new compute on the controller. """ - return await ComputesService(computes_repo).create_compute(compute_create) + return await ComputesService(computes_repo).create_compute(compute_create, connect) + + +@router.post("/{compute_id}/connect", status_code=status.HTTP_204_NO_CONTENT) +async def connect_compute(compute_id: Union[str, UUID]) -> Response: + """ + Connect to compute on the controller. + """ + + compute = Controller.instance().get_compute(str(compute_id)) + if not compute.connected: + await compute.connect(report_failed_connection=True) + return Response(status_code=status.HTTP_204_NO_CONTENT) @router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True) diff --git a/gns3server/api/server.py b/gns3server/api/server.py index d0f14d78..3c9fd2b6 100644 --- a/gns3server/api/server.py +++ b/gns3server/api/server.py @@ -21,8 +21,7 @@ FastAPI app import time -from fastapi import FastAPI, Request -from starlette.exceptions import HTTPException as StarletteHTTPException +from fastapi import FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from sqlalchemy.exc import SQLAlchemyError @@ -140,11 +139,12 @@ async def controller_bad_request_error_handler(request: Request, exc: Controller # make sure the content key is "message", not "detail" per default -@app.exception_handler(StarletteHTTPException) -async def http_exception_handler(request: Request, exc: StarletteHTTPException): +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): return JSONResponse( status_code=exc.status_code, content={"message": exc.detail}, + headers=exc.headers ) diff --git a/gns3server/controller/__init__.py b/gns3server/controller/__init__.py index 2622993d..ff5b9f7a 100644 --- a/gns3server/controller/__init__.py +++ b/gns3server/controller/__init__.py @@ -74,10 +74,6 @@ class Controller: if host == "0.0.0.0": host = "127.0.0.1" - name = socket.gethostname() - if name == "gns3vm": - name = "Main server" - self._load_controller_settings() if server_config.enable_ssl: @@ -93,7 +89,7 @@ class Controller: try: self._local_server = await self.add_compute( compute_id="local", - name=name, + name=f"{socket.gethostname()} (controller)", protocol=protocol, host=host, console_host=console_host, @@ -102,6 +98,7 @@ class Controller: password=server_config.compute_password, force=True, connect=True, + wait_connection=False, ssl_context=self._ssl_context, ) except ControllerError: @@ -113,7 +110,12 @@ class Controller: if computes: for c in computes: try: - await self.add_compute(**c, connect=False) + #FIXME: Task exception was never retrieved + await self.add_compute( + compute_id=str(c.compute_id), + connect=False, + **c.dict(exclude_unset=True, exclude={"compute_id", "created_at", "updated_at"}), + ) except (ControllerError, KeyError): pass # Skip not available servers at loading @@ -341,7 +343,7 @@ class Controller: os.makedirs(configs_path, exist_ok=True) return configs_path - async def add_compute(self, compute_id=None, name=None, force=False, connect=True, **kwargs): + async def add_compute(self, compute_id=None, name=None, force=False, connect=True, wait_connection=True, **kwargs): """ Add a server to the dictionary of computes controlled by this controller @@ -371,8 +373,11 @@ class Controller: self._computes[compute.id] = compute # self.save() if connect: - # call compute.connect() later to give time to the controller to be fully started - asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect())) + if wait_connection: + await compute.connect() + else: + # call compute.connect() later to give time to the controller to be fully started + asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect())) self.notification.controller_emit("compute.created", compute.asdict()) return compute else: diff --git a/gns3server/controller/compute.py b/gns3server/controller/compute.py index 4add4e90..f23d450e 100644 --- a/gns3server/controller/compute.py +++ b/gns3server/controller/compute.py @@ -154,6 +154,7 @@ class Compute: return self._interfaces_cache async def update(self, **kwargs): + for kw in kwargs: if kw not in ("user", "password"): setattr(self, kw, kwargs[kw]) @@ -373,7 +374,7 @@ class Compute: pass @locking - async def connect(self): + async def connect(self, report_failed_connection=False): """ Check if remote server is accessible """ @@ -383,6 +384,8 @@ class Compute: log.info(f"Connecting to compute '{self._id}'") response = await self._run_http_query("GET", "/capabilities") except ComputeError as e: + if report_failed_connection: + raise log.warning(f"Cannot connect to compute '{self._id}': {e}") # Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb) if not hasattr(sys, "_called_from_test") or not sys._called_from_test: @@ -491,7 +494,7 @@ class Compute: # Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb) from gns3server.api.server import app if not app.state.exiting and not hasattr(sys, "_called_from_test"): - log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'") + log.info(f"Reconnecting to compute '{self._id}' WebSocket '{ws_url}'") asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect())) self._cpu_usage_percent = None @@ -572,7 +575,7 @@ class Compute: msg = "" if response.status == 401: - raise ControllerUnauthorizedError(f"Invalid authentication for compute {self.id}") + raise ControllerUnauthorizedError(f"Invalid authentication for compute '{self.name}' [{self.id}]") elif response.status == 403: raise ControllerForbiddenError(msg) elif response.status == 404: diff --git a/gns3server/db/repositories/computes.py b/gns3server/db/repositories/computes.py index 2ea00bbd..e5dfb1f8 100644 --- a/gns3server/db/repositories/computes.py +++ b/gns3server/db/repositories/computes.py @@ -52,18 +52,14 @@ class ComputesRepository(BaseRepository): async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: - password = compute_create.password - if password: - password = password.get_secret_value() - db_compute = models.Compute( compute_id=compute_create.compute_id, name=compute_create.name, - protocol=compute_create.protocol.value, + protocol=compute_create.protocol, host=compute_create.host, port=compute_create.port, user=compute_create.user, - password=password, + password=compute_create.password.get_secret_value(), ) self._db_session.add(db_compute) await self._db_session.commit() diff --git a/gns3server/db/repositories/users.py b/gns3server/db/repositories/users.py index 2db1516f..310d67a2 100644 --- a/gns3server/db/repositories/users.py +++ b/gns3server/db/repositories/users.py @@ -125,7 +125,7 @@ class UsersRepository(BaseRepository): async def authenticate_user(self, username: str, password: str) -> Optional[models.User]: """ - Authenticate an user. + Authenticate user. """ user = await self.get_user_by_username(username) diff --git a/gns3server/db/tasks.py b/gns3server/db/tasks.py index c3c5ec0c..7c40a6fb 100644 --- a/gns3server/db/tasks.py +++ b/gns3server/db/tasks.py @@ -68,9 +68,7 @@ async def get_computes(app: FastAPI) -> List[dict]: db_computes = await ComputesRepository(db_session).get_computes() for db_compute in db_computes: try: - compute = jsonable_encoder( - schemas.Compute.from_orm(db_compute), exclude_unset=True, exclude={"created_at", "updated_at"} - ) + compute = schemas.Compute.from_orm(db_compute) except ValidationError as e: log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}") continue diff --git a/gns3server/schemas/controller/computes.py b/gns3server/schemas/controller/computes.py index 00343e64..1c035dc7 100644 --- a/gns3server/schemas/controller/computes.py +++ b/gns3server/schemas/controller/computes.py @@ -41,9 +41,13 @@ class ComputeBase(BaseModel): protocol: Protocol host: str port: int = Field(..., gt=0, le=65535) - user: Optional[str] = None + user: str = None + password: Optional[SecretStr] = None name: Optional[str] = None + class Config: + use_enum_values = True + class ComputeCreate(ComputeBase): """ @@ -51,7 +55,6 @@ class ComputeCreate(ComputeBase): """ compute_id: Union[str, uuid.UUID] = None - password: Optional[SecretStr] = None class Config: schema_extra = { @@ -102,6 +105,7 @@ class ComputeUpdate(ComputeBase): protocol: Optional[Protocol] = None host: Optional[str] = None port: Optional[int] = Field(None, gt=0, le=65535) + user: Optional[str] = None password: Optional[SecretStr] = None class Config: diff --git a/gns3server/services/computes.py b/gns3server/services/computes.py index 737fc13f..46a58b0e 100644 --- a/gns3server/services/computes.py +++ b/gns3server/services/computes.py @@ -41,17 +41,17 @@ class ComputesService: db_computes = await self._computes_repo.get_computes() return db_computes - async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: + async def create_compute(self, compute_create: schemas.ComputeCreate, connect: bool = False) -> models.Compute: if await self._computes_repo.get_compute(compute_create.compute_id): raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered") db_compute = await self._computes_repo.create_compute(compute_create) - await self._controller.add_compute( + compute = await self._controller.add_compute( compute_id=str(db_compute.compute_id), - connect=False, + connect=connect, **compute_create.dict(exclude_unset=True, exclude={"compute_id"}), ) - self._controller.notification.controller_emit("compute.created", db_compute.asjson()) + self._controller.notification.controller_emit("compute.created", compute.asdict()) return db_compute async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute: @@ -70,7 +70,7 @@ class ComputesService: db_compute = await self._computes_repo.update_compute(compute_id, compute_update) if not db_compute: raise ControllerNotFoundError(f"Compute '{compute_id}' not found") - self._controller.notification.controller_emit("compute.updated", db_compute.asjson()) + self._controller.notification.controller_emit("compute.updated", compute.asdict()) return db_compute async def delete_compute(self, compute_id: Union[str, UUID]) -> None: diff --git a/tests/controller/test_controller.py b/tests/controller/test_controller.py index 9f2060e9..28c77709 100644 --- a/tests/controller/test_controller.py +++ b/tests/controller/test_controller.py @@ -248,7 +248,7 @@ async def test_start(controller): await controller.start() #assert mock.called assert len(controller.computes) == 1 # Local compute is created - assert controller.computes["local"].name == socket.gethostname() + assert controller.computes["local"].name == f"{socket.gethostname()} (controller)" @pytest.mark.asyncio