mirror of
https://github.com/GNS3/gns3-server
synced 2024-11-12 19:38:57 +00:00
Add connect endpoint for computes
Param to connect to compute after creation Report compute unauthorized HTTP errors to client
This commit is contained in:
parent
36cf43475d
commit
10fdd8fcf4
@ -19,7 +19,7 @@ API routes for computes.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Response, status
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from gns3server.controller import Controller
|
||||
@ -47,12 +47,25 @@ router = APIRouter(responses=responses)
|
||||
async def create_compute(
|
||||
compute_create: schemas.ComputeCreate,
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)),
|
||||
connect: Optional[bool] = False
|
||||
) -> schemas.Compute:
|
||||
"""
|
||||
Create a new compute on the controller.
|
||||
"""
|
||||
|
||||
return await ComputesService(computes_repo).create_compute(compute_create)
|
||||
return await ComputesService(computes_repo).create_compute(compute_create, connect)
|
||||
|
||||
|
||||
@router.post("/{compute_id}/connect", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def connect_compute(compute_id: Union[str, UUID]) -> Response:
|
||||
"""
|
||||
Connect to compute on the controller.
|
||||
"""
|
||||
|
||||
compute = Controller.instance().get_compute(str(compute_id))
|
||||
if not compute.connected:
|
||||
await compute.connect(report_failed_connection=True)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True)
|
||||
|
@ -21,8 +21,7 @@ FastAPI app
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@ -140,11 +139,12 @@ async def controller_bad_request_error_handler(request: Request, exc: Controller
|
||||
|
||||
|
||||
# make sure the content key is "message", not "detail" per default
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"message": exc.detail},
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
|
@ -74,10 +74,6 @@ class Controller:
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
|
||||
name = socket.gethostname()
|
||||
if name == "gns3vm":
|
||||
name = "Main server"
|
||||
|
||||
self._load_controller_settings()
|
||||
|
||||
if server_config.enable_ssl:
|
||||
@ -93,7 +89,7 @@ class Controller:
|
||||
try:
|
||||
self._local_server = await self.add_compute(
|
||||
compute_id="local",
|
||||
name=name,
|
||||
name=f"{socket.gethostname()} (controller)",
|
||||
protocol=protocol,
|
||||
host=host,
|
||||
console_host=console_host,
|
||||
@ -102,6 +98,7 @@ class Controller:
|
||||
password=server_config.compute_password,
|
||||
force=True,
|
||||
connect=True,
|
||||
wait_connection=False,
|
||||
ssl_context=self._ssl_context,
|
||||
)
|
||||
except ControllerError:
|
||||
@ -113,7 +110,12 @@ class Controller:
|
||||
if computes:
|
||||
for c in computes:
|
||||
try:
|
||||
await self.add_compute(**c, connect=False)
|
||||
#FIXME: Task exception was never retrieved
|
||||
await self.add_compute(
|
||||
compute_id=str(c.compute_id),
|
||||
connect=False,
|
||||
**c.dict(exclude_unset=True, exclude={"compute_id", "created_at", "updated_at"}),
|
||||
)
|
||||
except (ControllerError, KeyError):
|
||||
pass # Skip not available servers at loading
|
||||
|
||||
@ -341,7 +343,7 @@ class Controller:
|
||||
os.makedirs(configs_path, exist_ok=True)
|
||||
return configs_path
|
||||
|
||||
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, **kwargs):
|
||||
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, wait_connection=True, **kwargs):
|
||||
"""
|
||||
Add a server to the dictionary of computes controlled by this controller
|
||||
|
||||
@ -371,8 +373,11 @@ class Controller:
|
||||
self._computes[compute.id] = compute
|
||||
# self.save()
|
||||
if connect:
|
||||
# call compute.connect() later to give time to the controller to be fully started
|
||||
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
|
||||
if wait_connection:
|
||||
await compute.connect()
|
||||
else:
|
||||
# call compute.connect() later to give time to the controller to be fully started
|
||||
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
|
||||
self.notification.controller_emit("compute.created", compute.asdict())
|
||||
return compute
|
||||
else:
|
||||
|
@ -154,6 +154,7 @@ class Compute:
|
||||
return self._interfaces_cache
|
||||
|
||||
async def update(self, **kwargs):
|
||||
|
||||
for kw in kwargs:
|
||||
if kw not in ("user", "password"):
|
||||
setattr(self, kw, kwargs[kw])
|
||||
@ -373,7 +374,7 @@ class Compute:
|
||||
pass
|
||||
|
||||
@locking
|
||||
async def connect(self):
|
||||
async def connect(self, report_failed_connection=False):
|
||||
"""
|
||||
Check if remote server is accessible
|
||||
"""
|
||||
@ -383,6 +384,8 @@ class Compute:
|
||||
log.info(f"Connecting to compute '{self._id}'")
|
||||
response = await self._run_http_query("GET", "/capabilities")
|
||||
except ComputeError as e:
|
||||
if report_failed_connection:
|
||||
raise
|
||||
log.warning(f"Cannot connect to compute '{self._id}': {e}")
|
||||
# Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb)
|
||||
if not hasattr(sys, "_called_from_test") or not sys._called_from_test:
|
||||
@ -491,7 +494,7 @@ class Compute:
|
||||
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
|
||||
from gns3server.api.server import app
|
||||
if not app.state.exiting and not hasattr(sys, "_called_from_test"):
|
||||
log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'")
|
||||
log.info(f"Reconnecting to compute '{self._id}' WebSocket '{ws_url}'")
|
||||
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))
|
||||
|
||||
self._cpu_usage_percent = None
|
||||
@ -572,7 +575,7 @@ class Compute:
|
||||
msg = ""
|
||||
|
||||
if response.status == 401:
|
||||
raise ControllerUnauthorizedError(f"Invalid authentication for compute {self.id}")
|
||||
raise ControllerUnauthorizedError(f"Invalid authentication for compute '{self.name}' [{self.id}]")
|
||||
elif response.status == 403:
|
||||
raise ControllerForbiddenError(msg)
|
||||
elif response.status == 404:
|
||||
|
@ -52,18 +52,14 @@ class ComputesRepository(BaseRepository):
|
||||
|
||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
||||
|
||||
password = compute_create.password
|
||||
if password:
|
||||
password = password.get_secret_value()
|
||||
|
||||
db_compute = models.Compute(
|
||||
compute_id=compute_create.compute_id,
|
||||
name=compute_create.name,
|
||||
protocol=compute_create.protocol.value,
|
||||
protocol=compute_create.protocol,
|
||||
host=compute_create.host,
|
||||
port=compute_create.port,
|
||||
user=compute_create.user,
|
||||
password=password,
|
||||
password=compute_create.password.get_secret_value(),
|
||||
)
|
||||
self._db_session.add(db_compute)
|
||||
await self._db_session.commit()
|
||||
|
@ -125,7 +125,7 @@ class UsersRepository(BaseRepository):
|
||||
|
||||
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
|
||||
"""
|
||||
Authenticate an user.
|
||||
Authenticate user.
|
||||
"""
|
||||
|
||||
user = await self.get_user_by_username(username)
|
||||
|
@ -68,9 +68,7 @@ async def get_computes(app: FastAPI) -> List[dict]:
|
||||
db_computes = await ComputesRepository(db_session).get_computes()
|
||||
for db_compute in db_computes:
|
||||
try:
|
||||
compute = jsonable_encoder(
|
||||
schemas.Compute.from_orm(db_compute), exclude_unset=True, exclude={"created_at", "updated_at"}
|
||||
)
|
||||
compute = schemas.Compute.from_orm(db_compute)
|
||||
except ValidationError as e:
|
||||
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
|
||||
continue
|
||||
|
@ -41,9 +41,13 @@ class ComputeBase(BaseModel):
|
||||
protocol: Protocol
|
||||
host: str
|
||||
port: int = Field(..., gt=0, le=65535)
|
||||
user: Optional[str] = None
|
||||
user: str = None
|
||||
password: Optional[SecretStr] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ComputeCreate(ComputeBase):
|
||||
"""
|
||||
@ -51,7 +55,6 @@ class ComputeCreate(ComputeBase):
|
||||
"""
|
||||
|
||||
compute_id: Union[str, uuid.UUID] = None
|
||||
password: Optional[SecretStr] = None
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -102,6 +105,7 @@ class ComputeUpdate(ComputeBase):
|
||||
protocol: Optional[Protocol] = None
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = Field(None, gt=0, le=65535)
|
||||
user: Optional[str] = None
|
||||
password: Optional[SecretStr] = None
|
||||
|
||||
class Config:
|
||||
|
@ -41,17 +41,17 @@ class ComputesService:
|
||||
db_computes = await self._computes_repo.get_computes()
|
||||
return db_computes
|
||||
|
||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
||||
async def create_compute(self, compute_create: schemas.ComputeCreate, connect: bool = False) -> models.Compute:
|
||||
|
||||
if await self._computes_repo.get_compute(compute_create.compute_id):
|
||||
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
|
||||
db_compute = await self._computes_repo.create_compute(compute_create)
|
||||
await self._controller.add_compute(
|
||||
compute = await self._controller.add_compute(
|
||||
compute_id=str(db_compute.compute_id),
|
||||
connect=False,
|
||||
connect=connect,
|
||||
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}),
|
||||
)
|
||||
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
|
||||
self._controller.notification.controller_emit("compute.created", compute.asdict())
|
||||
return db_compute
|
||||
|
||||
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
|
||||
@ -70,7 +70,7 @@ class ComputesService:
|
||||
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
|
||||
if not db_compute:
|
||||
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
||||
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
|
||||
self._controller.notification.controller_emit("compute.updated", compute.asdict())
|
||||
return db_compute
|
||||
|
||||
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
|
||||
|
@ -248,7 +248,7 @@ async def test_start(controller):
|
||||
await controller.start()
|
||||
#assert mock.called
|
||||
assert len(controller.computes) == 1 # Local compute is created
|
||||
assert controller.computes["local"].name == socket.gethostname()
|
||||
assert controller.computes["local"].name == f"{socket.gethostname()} (controller)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user