1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-12 19:38:57 +00:00

Add connect endpoint for computes

Param to connect to compute after creation
Report compute unauthorized HTTP errors to client
This commit is contained in:
grossmj 2021-12-24 13:05:39 +10:30
parent 36cf43475d
commit 10fdd8fcf4
10 changed files with 55 additions and 36 deletions

View File

@ -19,7 +19,7 @@ API routes for computes.
"""
from fastapi import APIRouter, Depends, Response, status
from typing import List, Union
from typing import List, Union, Optional
from uuid import UUID
from gns3server.controller import Controller
@ -47,12 +47,25 @@ router = APIRouter(responses=responses)
async def create_compute(
compute_create: schemas.ComputeCreate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)),
connect: Optional[bool] = False
) -> schemas.Compute:
"""
Create a new compute on the controller.
"""
return await ComputesService(computes_repo).create_compute(compute_create)
return await ComputesService(computes_repo).create_compute(compute_create, connect)
@router.post("/{compute_id}/connect", status_code=status.HTTP_204_NO_CONTENT)
async def connect_compute(compute_id: Union[str, UUID]) -> Response:
"""
Connect to compute on the controller.
"""
compute = Controller.instance().get_compute(str(compute_id))
if not compute.connected:
await compute.connect(report_failed_connection=True)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True)

View File

@ -21,8 +21,7 @@ FastAPI app
import time
from fastapi import FastAPI, Request
from starlette.exceptions import HTTPException as StarletteHTTPException
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.exc import SQLAlchemyError
@ -140,11 +139,12 @@ async def controller_bad_request_error_handler(request: Request, exc: Controller
# make sure the content key is "message", not "detail" per default
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"message": exc.detail},
headers=exc.headers
)

View File

@ -74,10 +74,6 @@ class Controller:
if host == "0.0.0.0":
host = "127.0.0.1"
name = socket.gethostname()
if name == "gns3vm":
name = "Main server"
self._load_controller_settings()
if server_config.enable_ssl:
@ -93,7 +89,7 @@ class Controller:
try:
self._local_server = await self.add_compute(
compute_id="local",
name=name,
name=f"{socket.gethostname()} (controller)",
protocol=protocol,
host=host,
console_host=console_host,
@ -102,6 +98,7 @@ class Controller:
password=server_config.compute_password,
force=True,
connect=True,
wait_connection=False,
ssl_context=self._ssl_context,
)
except ControllerError:
@ -113,7 +110,12 @@ class Controller:
if computes:
for c in computes:
try:
await self.add_compute(**c, connect=False)
#FIXME: Task exception was never retrieved
await self.add_compute(
compute_id=str(c.compute_id),
connect=False,
**c.dict(exclude_unset=True, exclude={"compute_id", "created_at", "updated_at"}),
)
except (ControllerError, KeyError):
pass # Skip not available servers at loading
@ -341,7 +343,7 @@ class Controller:
os.makedirs(configs_path, exist_ok=True)
return configs_path
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, **kwargs):
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, wait_connection=True, **kwargs):
"""
Add a server to the dictionary of computes controlled by this controller
@ -371,8 +373,11 @@ class Controller:
self._computes[compute.id] = compute
# self.save()
if connect:
# call compute.connect() later to give time to the controller to be fully started
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
if wait_connection:
await compute.connect()
else:
# call compute.connect() later to give time to the controller to be fully started
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
self.notification.controller_emit("compute.created", compute.asdict())
return compute
else:

View File

@ -154,6 +154,7 @@ class Compute:
return self._interfaces_cache
async def update(self, **kwargs):
for kw in kwargs:
if kw not in ("user", "password"):
setattr(self, kw, kwargs[kw])
@ -373,7 +374,7 @@ class Compute:
pass
@locking
async def connect(self):
async def connect(self, report_failed_connection=False):
"""
Check if remote server is accessible
"""
@ -383,6 +384,8 @@ class Compute:
log.info(f"Connecting to compute '{self._id}'")
response = await self._run_http_query("GET", "/capabilities")
except ComputeError as e:
if report_failed_connection:
raise
log.warning(f"Cannot connect to compute '{self._id}': {e}")
# Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb)
if not hasattr(sys, "_called_from_test") or not sys._called_from_test:
@ -491,7 +494,7 @@ class Compute:
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
from gns3server.api.server import app
if not app.state.exiting and not hasattr(sys, "_called_from_test"):
log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'")
log.info(f"Reconnecting to compute '{self._id}' WebSocket '{ws_url}'")
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))
self._cpu_usage_percent = None
@ -572,7 +575,7 @@ class Compute:
msg = ""
if response.status == 401:
raise ControllerUnauthorizedError(f"Invalid authentication for compute {self.id}")
raise ControllerUnauthorizedError(f"Invalid authentication for compute '{self.name}' [{self.id}]")
elif response.status == 403:
raise ControllerForbiddenError(msg)
elif response.status == 404:

View File

@ -52,18 +52,14 @@ class ComputesRepository(BaseRepository):
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
password = compute_create.password
if password:
password = password.get_secret_value()
db_compute = models.Compute(
compute_id=compute_create.compute_id,
name=compute_create.name,
protocol=compute_create.protocol.value,
protocol=compute_create.protocol,
host=compute_create.host,
port=compute_create.port,
user=compute_create.user,
password=password,
password=compute_create.password.get_secret_value(),
)
self._db_session.add(db_compute)
await self._db_session.commit()

View File

@ -125,7 +125,7 @@ class UsersRepository(BaseRepository):
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
"""
Authenticate an user.
Authenticate user.
"""
user = await self.get_user_by_username(username)

View File

@ -68,9 +68,7 @@ async def get_computes(app: FastAPI) -> List[dict]:
db_computes = await ComputesRepository(db_session).get_computes()
for db_compute in db_computes:
try:
compute = jsonable_encoder(
schemas.Compute.from_orm(db_compute), exclude_unset=True, exclude={"created_at", "updated_at"}
)
compute = schemas.Compute.from_orm(db_compute)
except ValidationError as e:
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
continue

View File

@ -41,9 +41,13 @@ class ComputeBase(BaseModel):
protocol: Protocol
host: str
port: int = Field(..., gt=0, le=65535)
user: Optional[str] = None
user: str = None
password: Optional[SecretStr] = None
name: Optional[str] = None
class Config:
use_enum_values = True
class ComputeCreate(ComputeBase):
"""
@ -51,7 +55,6 @@ class ComputeCreate(ComputeBase):
"""
compute_id: Union[str, uuid.UUID] = None
password: Optional[SecretStr] = None
class Config:
schema_extra = {
@ -102,6 +105,7 @@ class ComputeUpdate(ComputeBase):
protocol: Optional[Protocol] = None
host: Optional[str] = None
port: Optional[int] = Field(None, gt=0, le=65535)
user: Optional[str] = None
password: Optional[SecretStr] = None
class Config:

View File

@ -41,17 +41,17 @@ class ComputesService:
db_computes = await self._computes_repo.get_computes()
return db_computes
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
async def create_compute(self, compute_create: schemas.ComputeCreate, connect: bool = False) -> models.Compute:
if await self._computes_repo.get_compute(compute_create.compute_id):
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
db_compute = await self._computes_repo.create_compute(compute_create)
await self._controller.add_compute(
compute = await self._controller.add_compute(
compute_id=str(db_compute.compute_id),
connect=False,
connect=connect,
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}),
)
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
self._controller.notification.controller_emit("compute.created", compute.asdict())
return db_compute
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
@ -70,7 +70,7 @@ class ComputesService:
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
self._controller.notification.controller_emit("compute.updated", compute.asdict())
return db_compute
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:

View File

@ -248,7 +248,7 @@ async def test_start(controller):
await controller.start()
#assert mock.called
assert len(controller.computes) == 1 # Local compute is created
assert controller.computes["local"].name == socket.gethostname()
assert controller.computes["local"].name == f"{socket.gethostname()} (controller)"
@pytest.mark.asyncio