mirror of
https://github.com/GNS3/gns3-server
synced 2025-02-20 03:52:00 +00:00
Add connect endpoint for computes
Param to connect to compute after creation Report compute unauthorized HTTP errors to client
This commit is contained in:
parent
36cf43475d
commit
10fdd8fcf4
@ -19,7 +19,7 @@ API routes for computes.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Response, status
|
from fastapi import APIRouter, Depends, Response, status
|
||||||
from typing import List, Union
|
from typing import List, Union, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from gns3server.controller import Controller
|
from gns3server.controller import Controller
|
||||||
@ -47,12 +47,25 @@ router = APIRouter(responses=responses)
|
|||||||
async def create_compute(
|
async def create_compute(
|
||||||
compute_create: schemas.ComputeCreate,
|
compute_create: schemas.ComputeCreate,
|
||||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)),
|
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)),
|
||||||
|
connect: Optional[bool] = False
|
||||||
) -> schemas.Compute:
|
) -> schemas.Compute:
|
||||||
"""
|
"""
|
||||||
Create a new compute on the controller.
|
Create a new compute on the controller.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await ComputesService(computes_repo).create_compute(compute_create)
|
return await ComputesService(computes_repo).create_compute(compute_create, connect)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{compute_id}/connect", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def connect_compute(compute_id: Union[str, UUID]) -> Response:
|
||||||
|
"""
|
||||||
|
Connect to compute on the controller.
|
||||||
|
"""
|
||||||
|
|
||||||
|
compute = Controller.instance().get_compute(str(compute_id))
|
||||||
|
if not compute.connected:
|
||||||
|
await compute.connect(report_failed_connection=True)
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True)
|
@router.get("/{compute_id}", response_model=schemas.Compute, response_model_exclude_unset=True)
|
||||||
|
@ -21,8 +21,7 @@ FastAPI app
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
@ -140,11 +139,12 @@ async def controller_bad_request_error_handler(request: Request, exc: Controller
|
|||||||
|
|
||||||
|
|
||||||
# make sure the content key is "message", not "detail" per default
|
# make sure the content key is "message", not "detail" per default
|
||||||
@app.exception_handler(StarletteHTTPException)
|
@app.exception_handler(HTTPException)
|
||||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
content={"message": exc.detail},
|
content={"message": exc.detail},
|
||||||
|
headers=exc.headers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,10 +74,6 @@ class Controller:
|
|||||||
if host == "0.0.0.0":
|
if host == "0.0.0.0":
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1"
|
||||||
|
|
||||||
name = socket.gethostname()
|
|
||||||
if name == "gns3vm":
|
|
||||||
name = "Main server"
|
|
||||||
|
|
||||||
self._load_controller_settings()
|
self._load_controller_settings()
|
||||||
|
|
||||||
if server_config.enable_ssl:
|
if server_config.enable_ssl:
|
||||||
@ -93,7 +89,7 @@ class Controller:
|
|||||||
try:
|
try:
|
||||||
self._local_server = await self.add_compute(
|
self._local_server = await self.add_compute(
|
||||||
compute_id="local",
|
compute_id="local",
|
||||||
name=name,
|
name=f"{socket.gethostname()} (controller)",
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
host=host,
|
host=host,
|
||||||
console_host=console_host,
|
console_host=console_host,
|
||||||
@ -102,6 +98,7 @@ class Controller:
|
|||||||
password=server_config.compute_password,
|
password=server_config.compute_password,
|
||||||
force=True,
|
force=True,
|
||||||
connect=True,
|
connect=True,
|
||||||
|
wait_connection=False,
|
||||||
ssl_context=self._ssl_context,
|
ssl_context=self._ssl_context,
|
||||||
)
|
)
|
||||||
except ControllerError:
|
except ControllerError:
|
||||||
@ -113,7 +110,12 @@ class Controller:
|
|||||||
if computes:
|
if computes:
|
||||||
for c in computes:
|
for c in computes:
|
||||||
try:
|
try:
|
||||||
await self.add_compute(**c, connect=False)
|
#FIXME: Task exception was never retrieved
|
||||||
|
await self.add_compute(
|
||||||
|
compute_id=str(c.compute_id),
|
||||||
|
connect=False,
|
||||||
|
**c.dict(exclude_unset=True, exclude={"compute_id", "created_at", "updated_at"}),
|
||||||
|
)
|
||||||
except (ControllerError, KeyError):
|
except (ControllerError, KeyError):
|
||||||
pass # Skip not available servers at loading
|
pass # Skip not available servers at loading
|
||||||
|
|
||||||
@ -341,7 +343,7 @@ class Controller:
|
|||||||
os.makedirs(configs_path, exist_ok=True)
|
os.makedirs(configs_path, exist_ok=True)
|
||||||
return configs_path
|
return configs_path
|
||||||
|
|
||||||
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, **kwargs):
|
async def add_compute(self, compute_id=None, name=None, force=False, connect=True, wait_connection=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Add a server to the dictionary of computes controlled by this controller
|
Add a server to the dictionary of computes controlled by this controller
|
||||||
|
|
||||||
@ -371,8 +373,11 @@ class Controller:
|
|||||||
self._computes[compute.id] = compute
|
self._computes[compute.id] = compute
|
||||||
# self.save()
|
# self.save()
|
||||||
if connect:
|
if connect:
|
||||||
# call compute.connect() later to give time to the controller to be fully started
|
if wait_connection:
|
||||||
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
|
await compute.connect()
|
||||||
|
else:
|
||||||
|
# call compute.connect() later to give time to the controller to be fully started
|
||||||
|
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(compute.connect()))
|
||||||
self.notification.controller_emit("compute.created", compute.asdict())
|
self.notification.controller_emit("compute.created", compute.asdict())
|
||||||
return compute
|
return compute
|
||||||
else:
|
else:
|
||||||
|
@ -154,6 +154,7 @@ class Compute:
|
|||||||
return self._interfaces_cache
|
return self._interfaces_cache
|
||||||
|
|
||||||
async def update(self, **kwargs):
|
async def update(self, **kwargs):
|
||||||
|
|
||||||
for kw in kwargs:
|
for kw in kwargs:
|
||||||
if kw not in ("user", "password"):
|
if kw not in ("user", "password"):
|
||||||
setattr(self, kw, kwargs[kw])
|
setattr(self, kw, kwargs[kw])
|
||||||
@ -373,7 +374,7 @@ class Compute:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@locking
|
@locking
|
||||||
async def connect(self):
|
async def connect(self, report_failed_connection=False):
|
||||||
"""
|
"""
|
||||||
Check if remote server is accessible
|
Check if remote server is accessible
|
||||||
"""
|
"""
|
||||||
@ -383,6 +384,8 @@ class Compute:
|
|||||||
log.info(f"Connecting to compute '{self._id}'")
|
log.info(f"Connecting to compute '{self._id}'")
|
||||||
response = await self._run_http_query("GET", "/capabilities")
|
response = await self._run_http_query("GET", "/capabilities")
|
||||||
except ComputeError as e:
|
except ComputeError as e:
|
||||||
|
if report_failed_connection:
|
||||||
|
raise
|
||||||
log.warning(f"Cannot connect to compute '{self._id}': {e}")
|
log.warning(f"Cannot connect to compute '{self._id}': {e}")
|
||||||
# Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb)
|
# Try to reconnect after 5 seconds if server unavailable only if not during tests (otherwise we create a ressource usage bomb)
|
||||||
if not hasattr(sys, "_called_from_test") or not sys._called_from_test:
|
if not hasattr(sys, "_called_from_test") or not sys._called_from_test:
|
||||||
@ -491,7 +494,7 @@ class Compute:
|
|||||||
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
|
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
|
||||||
from gns3server.api.server import app
|
from gns3server.api.server import app
|
||||||
if not app.state.exiting and not hasattr(sys, "_called_from_test"):
|
if not app.state.exiting and not hasattr(sys, "_called_from_test"):
|
||||||
log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'")
|
log.info(f"Reconnecting to compute '{self._id}' WebSocket '{ws_url}'")
|
||||||
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))
|
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))
|
||||||
|
|
||||||
self._cpu_usage_percent = None
|
self._cpu_usage_percent = None
|
||||||
@ -572,7 +575,7 @@ class Compute:
|
|||||||
msg = ""
|
msg = ""
|
||||||
|
|
||||||
if response.status == 401:
|
if response.status == 401:
|
||||||
raise ControllerUnauthorizedError(f"Invalid authentication for compute {self.id}")
|
raise ControllerUnauthorizedError(f"Invalid authentication for compute '{self.name}' [{self.id}]")
|
||||||
elif response.status == 403:
|
elif response.status == 403:
|
||||||
raise ControllerForbiddenError(msg)
|
raise ControllerForbiddenError(msg)
|
||||||
elif response.status == 404:
|
elif response.status == 404:
|
||||||
|
@ -52,18 +52,14 @@ class ComputesRepository(BaseRepository):
|
|||||||
|
|
||||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
||||||
|
|
||||||
password = compute_create.password
|
|
||||||
if password:
|
|
||||||
password = password.get_secret_value()
|
|
||||||
|
|
||||||
db_compute = models.Compute(
|
db_compute = models.Compute(
|
||||||
compute_id=compute_create.compute_id,
|
compute_id=compute_create.compute_id,
|
||||||
name=compute_create.name,
|
name=compute_create.name,
|
||||||
protocol=compute_create.protocol.value,
|
protocol=compute_create.protocol,
|
||||||
host=compute_create.host,
|
host=compute_create.host,
|
||||||
port=compute_create.port,
|
port=compute_create.port,
|
||||||
user=compute_create.user,
|
user=compute_create.user,
|
||||||
password=password,
|
password=compute_create.password.get_secret_value(),
|
||||||
)
|
)
|
||||||
self._db_session.add(db_compute)
|
self._db_session.add(db_compute)
|
||||||
await self._db_session.commit()
|
await self._db_session.commit()
|
||||||
|
@ -125,7 +125,7 @@ class UsersRepository(BaseRepository):
|
|||||||
|
|
||||||
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
|
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
|
||||||
"""
|
"""
|
||||||
Authenticate an user.
|
Authenticate user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user = await self.get_user_by_username(username)
|
user = await self.get_user_by_username(username)
|
||||||
|
@ -68,9 +68,7 @@ async def get_computes(app: FastAPI) -> List[dict]:
|
|||||||
db_computes = await ComputesRepository(db_session).get_computes()
|
db_computes = await ComputesRepository(db_session).get_computes()
|
||||||
for db_compute in db_computes:
|
for db_compute in db_computes:
|
||||||
try:
|
try:
|
||||||
compute = jsonable_encoder(
|
compute = schemas.Compute.from_orm(db_compute)
|
||||||
schemas.Compute.from_orm(db_compute), exclude_unset=True, exclude={"created_at", "updated_at"}
|
|
||||||
)
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
|
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
|
||||||
continue
|
continue
|
||||||
|
@ -41,9 +41,13 @@ class ComputeBase(BaseModel):
|
|||||||
protocol: Protocol
|
protocol: Protocol
|
||||||
host: str
|
host: str
|
||||||
port: int = Field(..., gt=0, le=65535)
|
port: int = Field(..., gt=0, le=65535)
|
||||||
user: Optional[str] = None
|
user: str = None
|
||||||
|
password: Optional[SecretStr] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
class ComputeCreate(ComputeBase):
|
class ComputeCreate(ComputeBase):
|
||||||
"""
|
"""
|
||||||
@ -51,7 +55,6 @@ class ComputeCreate(ComputeBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
compute_id: Union[str, uuid.UUID] = None
|
compute_id: Union[str, uuid.UUID] = None
|
||||||
password: Optional[SecretStr] = None
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -102,6 +105,7 @@ class ComputeUpdate(ComputeBase):
|
|||||||
protocol: Optional[Protocol] = None
|
protocol: Optional[Protocol] = None
|
||||||
host: Optional[str] = None
|
host: Optional[str] = None
|
||||||
port: Optional[int] = Field(None, gt=0, le=65535)
|
port: Optional[int] = Field(None, gt=0, le=65535)
|
||||||
|
user: Optional[str] = None
|
||||||
password: Optional[SecretStr] = None
|
password: Optional[SecretStr] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -41,17 +41,17 @@ class ComputesService:
|
|||||||
db_computes = await self._computes_repo.get_computes()
|
db_computes = await self._computes_repo.get_computes()
|
||||||
return db_computes
|
return db_computes
|
||||||
|
|
||||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
async def create_compute(self, compute_create: schemas.ComputeCreate, connect: bool = False) -> models.Compute:
|
||||||
|
|
||||||
if await self._computes_repo.get_compute(compute_create.compute_id):
|
if await self._computes_repo.get_compute(compute_create.compute_id):
|
||||||
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
|
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
|
||||||
db_compute = await self._computes_repo.create_compute(compute_create)
|
db_compute = await self._computes_repo.create_compute(compute_create)
|
||||||
await self._controller.add_compute(
|
compute = await self._controller.add_compute(
|
||||||
compute_id=str(db_compute.compute_id),
|
compute_id=str(db_compute.compute_id),
|
||||||
connect=False,
|
connect=connect,
|
||||||
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}),
|
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}),
|
||||||
)
|
)
|
||||||
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
|
self._controller.notification.controller_emit("compute.created", compute.asdict())
|
||||||
return db_compute
|
return db_compute
|
||||||
|
|
||||||
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
|
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
|
||||||
@ -70,7 +70,7 @@ class ComputesService:
|
|||||||
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
|
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
|
||||||
if not db_compute:
|
if not db_compute:
|
||||||
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
||||||
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
|
self._controller.notification.controller_emit("compute.updated", compute.asdict())
|
||||||
return db_compute
|
return db_compute
|
||||||
|
|
||||||
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
|
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
|
||||||
|
@ -248,7 +248,7 @@ async def test_start(controller):
|
|||||||
await controller.start()
|
await controller.start()
|
||||||
#assert mock.called
|
#assert mock.called
|
||||||
assert len(controller.computes) == 1 # Local compute is created
|
assert len(controller.computes) == 1 # Local compute is created
|
||||||
assert controller.computes["local"].name == socket.gethostname()
|
assert controller.computes["local"].name == f"{socket.gethostname()} (controller)"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user