Add connect endpoint for computes

Param to connect to compute after creation
Report compute unauthorized HTTP errors to client
pull/2025/head
grossmj 3 years ago
parent 36cf43475d
commit 10fdd8fcf4

@ -19,7 +19,7 @@ API routes for computes.
""" """
from fastapi import APIRouter, Depends, Response, status from fastapi import APIRouter, Depends, Response, status
from typing import List, Union from typing import List, Union, Optional
from uuid import UUID from uuid import UUID
from gns3server.controller import Controller from gns3server.controller import Controller
@ -47,12 +47,25 @@ router = APIRouter(responses=responses)
async def create_compute( async def create_compute(
compute_create: schemas.ComputeCreate, compute_create: schemas.ComputeCreate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)), computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)),
connect: Optional[bool] = False
) -> schemas.Compute: ) -> schemas.Compute:
""" """
Create a new compute on the controller. Create a new compute on the controller.
""" """
return await ComputesService(computes_repo).create_compute(compute_create) return await ComputesService(computes_repo).create_compute(compute_create, connect)
@router.post("/{compute_id}/connect", status_code=status.HTTP_204_NO_CONTENT)
async def connect_compute(compute_id: Union[str, UUID]) -> Response:
"""
Connect to compute on the controller.
"""
compute = Controller.instance().get_compute(str(compute_id))
if not compute.connected:
await compute.connect(report_failed_connection=True)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True) @router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True)

@ -21,8 +21,7 @@ FastAPI app
import time import time
from fastapi import FastAPI, Request from fastapi import FastAPI, Request, HTTPException
from starlette.exceptions import HTTPException as StarletteHTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -140,11 +139,12 @@ async def controller_bad_request_error_handler(request: Request, exc: Controller
# make sure the content key is "message", not "detail" per default # make sure the content key is "message", not "detail" per default
@app.exception_handler(StarletteHTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content={"message": exc.detail}, content={"message": exc.detail},
headers=exc.headers
) )

@ -74,10 +74,6 @@ class Controller:
if host == "0.0.0.0": if host == "0.0.0.0":
host = "127.0.0.1" host = "127.0.0.1"
name = socket.gethostname()
if name == "gns3vm":
name = "Main server"
self._load_controller_settings() self._load_controller_settings()
if server_config.enable_ssl: if server_config.enable_ssl:
@ -93,7 +89,7 @@ class Controller:
try: try:
self._local_server = await self.add_compute( self._local_server = await self.add_compute(
compute_id="local", compute_id="local",
name=name, name=f"{socket.gethostname()} (controller)",
protocol=protocol, protocol=protocol,
host=host, host=host,
console_host=console_host, console_host=console_host,
@ -102,6 +98,7 @@ class Controller:
password=server_config.compute_password, password=server_config.compute_password,
force=True, force=True,
connect=True, connect=True,
wait_connection=False,
ssl_context=self._ssl_context, ssl_context=self._ssl_context,
) )
except ControllerError: except ControllerError:
@ -113,7 +110,12 @@ class Controller:
if computes: if computes:
for c in computes: for c in computes:
try: try:
await self.add_compute(**c, connect=False) #FIXME: Task exception was never retrieved
await self.add_compute(
compute_id=str(c.compute_id),
connect=False,
**c.dict(exclude_unset=True, exclude={"compute_id", "created_at", "updated_at"}),
)
except (ControllerError, KeyError): except (ControllerError, KeyError):
pass # Skip not available servers at loading pass # Skip not available servers at loading
@ -341,7 +343,7 @@ class Controller:
os.makedirs(configs_path, exist_ok=True) os.makedirs(configs_path, exist_ok=True)
return configs_path return configs_path
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, **kwargs): async def add_compute(self, compute_id=None, name=None, force=False, connect=True, wait_connection=True, **kwargs):
""" """
Add a server to the dictionary of computes controlled by this controller Add a server to the dictionary of computes controlled by this controller
@ -371,8 +373,11 @@ class Controller:
self._computes[compute.id] = compute self._computes[compute.id] = compute
# self.save() # self.save()
if connect: if connect:
# call compute.connect() later to give time to the controller to be fully started if wait_connection:
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect())) await compute.connect()
else:
# call compute.connect() later to give time to the controller to be fully started
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
self.notification.controller_emit("compute.created", compute.asdict()) self.notification.controller_emit("compute.created", compute.asdict())
return compute return compute
else: else:

@ -154,6 +154,7 @@ class Compute:
return self._interfaces_cache return self._interfaces_cache
async def update(self, **kwargs): async def update(self, **kwargs):
for kw in kwargs: for kw in kwargs:
if kw not in ("user", "password"): if kw not in ("user", "password"):
setattr(self, kw, kwargs[kw]) setattr(self, kw, kwargs[kw])
@ -373,7 +374,7 @@ class Compute:
pass pass
@locking @locking
async def connect(self): async def connect(self, report_failed_connection=False):
""" """
Check if remote server is accessible Check if remote server is accessible
""" """
@ -383,6 +384,8 @@ class Compute:
log.info(f"Connecting to compute '{self._id}'") log.info(f"Connecting to compute '{self._id}'")
response = await self._run_http_query("GET", "/capabilities") response = await self._run_http_query("GET", "/capabilities")
except ComputeError as e: except ComputeError as e:
if report_failed_connection:
raise
log.warning(f"Cannot connect to compute '{self._id}': {e}") log.warning(f"Cannot connect to compute '{self._id}': {e}")
# Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb) # Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb)
if not hasattr(sys, "_called_from_test") or not sys._called_from_test: if not hasattr(sys, "_called_from_test") or not sys._called_from_test:
@ -491,7 +494,7 @@ class Compute:
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb) # Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
from gns3server.api.server import app from gns3server.api.server import app
if not app.state.exiting and not hasattr(sys, "_called_from_test"): if not app.state.exiting and not hasattr(sys, "_called_from_test"):
log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'") log.info(f"Reconnecting to compute '{self._id}' WebSocket '{ws_url}'")
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect())) asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))
self._cpu_usage_percent = None self._cpu_usage_percent = None
@ -572,7 +575,7 @@ class Compute:
msg = "" msg = ""
if response.status == 401: if response.status == 401:
raise ControllerUnauthorizedError(f"Invalid authentication for compute {self.id}") raise ControllerUnauthorizedError(f"Invalid authentication for compute '{self.name}' [{self.id}]")
elif response.status == 403: elif response.status == 403:
raise ControllerForbiddenError(msg) raise ControllerForbiddenError(msg)
elif response.status == 404: elif response.status == 404:

@ -52,18 +52,14 @@ class ComputesRepository(BaseRepository):
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
password = compute_create.password
if password:
password = password.get_secret_value()
db_compute = models.Compute( db_compute = models.Compute(
compute_id=compute_create.compute_id, compute_id=compute_create.compute_id,
name=compute_create.name, name=compute_create.name,
protocol=compute_create.protocol.value, protocol=compute_create.protocol,
host=compute_create.host, host=compute_create.host,
port=compute_create.port, port=compute_create.port,
user=compute_create.user, user=compute_create.user,
password=password, password=compute_create.password.get_secret_value(),
) )
self._db_session.add(db_compute) self._db_session.add(db_compute)
await self._db_session.commit() await self._db_session.commit()

@ -125,7 +125,7 @@ class UsersRepository(BaseRepository):
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]: async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
""" """
Authenticate an user. Authenticate user.
""" """
user = await self.get_user_by_username(username) user = await self.get_user_by_username(username)

@ -68,9 +68,7 @@ async def get_computes(app: FastAPI) -> List[dict]:
db_computes = await ComputesRepository(db_session).get_computes() db_computes = await ComputesRepository(db_session).get_computes()
for db_compute in db_computes: for db_compute in db_computes:
try: try:
compute = jsonable_encoder( compute = schemas.Compute.from_orm(db_compute)
schemas.Compute.from_orm(db_compute), exclude_unset=True, exclude={"created_at", "updated_at"}
)
except ValidationError as e: except ValidationError as e:
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}") log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
continue continue

@ -41,9 +41,13 @@ class ComputeBase(BaseModel):
protocol: Protocol protocol: Protocol
host: str host: str
port: int = Field(..., gt=0, le=65535) port: int = Field(..., gt=0, le=65535)
user: Optional[str] = None user: str = None
password: Optional[SecretStr] = None
name: Optional[str] = None name: Optional[str] = None
class Config:
use_enum_values = True
class ComputeCreate(ComputeBase): class ComputeCreate(ComputeBase):
""" """
@ -51,7 +55,6 @@ class ComputeCreate(ComputeBase):
""" """
compute_id: Union[str, uuid.UUID] = None compute_id: Union[str, uuid.UUID] = None
password: Optional[SecretStr] = None
class Config: class Config:
schema_extra = { schema_extra = {
@ -102,6 +105,7 @@ class ComputeUpdate(ComputeBase):
protocol: Optional[Protocol] = None protocol: Optional[Protocol] = None
host: Optional[str] = None host: Optional[str] = None
port: Optional[int] = Field(None, gt=0, le=65535) port: Optional[int] = Field(None, gt=0, le=65535)
user: Optional[str] = None
password: Optional[SecretStr] = None password: Optional[SecretStr] = None
class Config: class Config:

@ -41,17 +41,17 @@ class ComputesService:
db_computes = await self._computes_repo.get_computes() db_computes = await self._computes_repo.get_computes()
return db_computes return db_computes
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: async def create_compute(self, compute_create: schemas.ComputeCreate, connect: bool = False) -> models.Compute:
if await self._computes_repo.get_compute(compute_create.compute_id): if await self._computes_repo.get_compute(compute_create.compute_id):
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered") raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
db_compute = await self._computes_repo.create_compute(compute_create) db_compute = await self._computes_repo.create_compute(compute_create)
await self._controller.add_compute( compute = await self._controller.add_compute(
compute_id=str(db_compute.compute_id), compute_id=str(db_compute.compute_id),
connect=False, connect=connect,
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}), **compute_create.dict(exclude_unset=True, exclude={"compute_id"}),
) )
self._controller.notification.controller_emit("compute.created", db_compute.asjson()) self._controller.notification.controller_emit("compute.created", compute.asdict())
return db_compute return db_compute
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute: async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
@ -70,7 +70,7 @@ class ComputesService:
db_compute = await self._computes_repo.update_compute(compute_id, compute_update) db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
if not db_compute: if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found") raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
self._controller.notification.controller_emit("compute.updated", db_compute.asjson()) self._controller.notification.controller_emit("compute.updated", compute.asdict())
return db_compute return db_compute
async def delete_compute(self, compute_id: Union[str, UUID]) -> None: async def delete_compute(self, compute_id: Union[str, UUID]) -> None:

@ -248,7 +248,7 @@ async def test_start(controller):
await controller.start() await controller.start()
#assert mock.called #assert mock.called
assert len(controller.computes) == 1 # Local compute is created assert len(controller.computes) == 1 # Local compute is created
assert controller.computes["local"].name == socket.gethostname() assert controller.computes["local"].name == f"{socket.gethostname()} (controller)"
@pytest.mark.asyncio @pytest.mark.asyncio

Loading…
Cancel
Save