mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-26 16:01:23 +00:00
Secure websocket endpoints
This commit is contained in:
parent
7ce5e19a6e
commit
741fc4a557
@ -16,8 +16,9 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from fastapi import Request, Depends, HTTPException, status
|
from fastapi import Request, Query, Depends, HTTPException, WebSocket, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from gns3server import schemas
|
from gns3server import schemas
|
||||||
from gns3server.db.repositories.users import UsersRepository
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
@ -76,3 +77,53 @@ async def get_current_active_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_active_user_from_websocket(
|
||||||
|
websocket: WebSocket,
|
||||||
|
token: str = Query(...),
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
|
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
|
||||||
|
) -> Optional[schemas.User]:
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
try:
|
||||||
|
username = auth_service.get_username_from_token(token)
|
||||||
|
user = await user_repo.get_user_by_username(username)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"Could not validate credentials for '{username}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Super admin is always authorized
|
||||||
|
if user.is_superadmin:
|
||||||
|
return user
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"'{username}' is not an active user"
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove the prefix (e.g. "/v3") from URL path
|
||||||
|
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
|
||||||
|
|
||||||
|
# there are no HTTP methods for web sockets, assuming "GET"...
|
||||||
|
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
|
||||||
|
if not authorized:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"User is not authorized '{user.user_id}' on '{path}'",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to "
|
||||||
|
f"WebSocket: {e.detail}"}}
|
||||||
|
await websocket.send_json(websocket_error)
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
@ -15,13 +15,14 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from typing import Callable, Type
|
from typing import Callable, Type
|
||||||
from fastapi import Depends, Request
|
from fastapi import Depends
|
||||||
|
from starlette.requests import HTTPConnection
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from gns3server.db.repositories.base import BaseRepository
|
from gns3server.db.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
async def get_db_session(request: Request) -> AsyncSession:
|
async def get_db_session(request: HTTPConnection) -> AsyncSession:
|
||||||
|
|
||||||
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
|
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
|
||||||
try:
|
try:
|
||||||
|
@ -18,14 +18,14 @@
|
|||||||
API routes for controller notifications.
|
API routes for controller notifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from websockets.exceptions import ConnectionClosed, WebSocketException
|
from websockets.exceptions import ConnectionClosed, WebSocketException
|
||||||
|
|
||||||
from gns3server.services import auth_service
|
|
||||||
from gns3server.controller import Controller
|
from gns3server.controller import Controller
|
||||||
|
from gns3server import schemas
|
||||||
|
|
||||||
from .dependencies.authentication import get_current_active_user
|
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", dependencies=[Depends(get_current_active_user)])
|
@router.get("", dependencies=[Depends(get_current_active_user)])
|
||||||
async def http_notification() -> StreamingResponse:
|
async def controller_http_notifications() -> StreamingResponse:
|
||||||
"""
|
"""
|
||||||
Receive controller notifications about the controller from HTTP stream.
|
Receive controller notifications about the controller from HTTP stream.
|
||||||
"""
|
"""
|
||||||
@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws")
|
@router.websocket("/ws")
|
||||||
async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None:
|
async def controller_ws_notifications(
|
||||||
|
websocket: WebSocket,
|
||||||
|
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Receive project notifications about the controller from WebSocket.
|
Receive project notifications about the controller from WebSocket.
|
||||||
"""
|
"""
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
if token:
|
if current_user is None:
|
||||||
try:
|
return
|
||||||
username = auth_service.get_username_from_token(token)
|
|
||||||
except HTTPException:
|
|
||||||
log.error("Invalid token received")
|
|
||||||
await websocket.close(code=1008)
|
|
||||||
return
|
|
||||||
|
|
||||||
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
|
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
|
||||||
try:
|
try:
|
||||||
|
@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository
|
|||||||
from gns3server.db.repositories.templates import TemplatesRepository
|
from gns3server.db.repositories.templates import TemplatesRepository
|
||||||
from gns3server.services.templates import TemplatesService
|
from gns3server.services.templates import TemplatesService
|
||||||
|
|
||||||
from .dependencies.authentication import get_current_active_user
|
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
|
||||||
from .dependencies.database import get_repository
|
from .dependencies.database import get_repository
|
||||||
|
|
||||||
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
|
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
|
||||||
@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{project_id}/notifications")
|
@router.get("/{project_id}/notifications")
|
||||||
async def notification(project_id: UUID) -> StreamingResponse:
|
async def project_http_notifications(project_id: UUID) -> StreamingResponse:
|
||||||
"""
|
"""
|
||||||
Receive project notifications about the controller from HTTP stream.
|
Receive project notifications about the controller from HTTP stream.
|
||||||
"""
|
"""
|
||||||
@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.websocket("/{project_id}/notifications/ws")
|
@router.websocket("/{project_id}/notifications/ws")
|
||||||
async def notification_ws(project_id: UUID, websocket: WebSocket) -> None:
|
async def project_ws_notifications(
|
||||||
|
project_id: UUID,
|
||||||
|
websocket: WebSocket,
|
||||||
|
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Receive project notifications about the controller from WebSocket.
|
Receive project notifications about the controller from WebSocket.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if current_user is None:
|
||||||
|
return
|
||||||
|
|
||||||
controller = Controller.instance()
|
controller = Controller.instance()
|
||||||
project = controller.get_project(str(project_id))
|
project = controller.get_project(str(project_id))
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
|
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
|
||||||
try:
|
try:
|
||||||
|
@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw):
|
|||||||
"path": "/",
|
"path": "/",
|
||||||
"action": "ALLOW"
|
"action": "ALLOW"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"description": "Allow to receive controller notifications",
|
||||||
|
"methods": ["GET"],
|
||||||
|
"path": "/notifications",
|
||||||
|
"action": "ALLOW"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"description": "Allow to create and list projects",
|
"description": "Allow to create and list projects",
|
||||||
"methods": ["GET", "POST"],
|
"methods": ["GET", "POST"],
|
||||||
@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw):
|
|||||||
role_id = result.first().role_id
|
role_id = result.first().role_id
|
||||||
|
|
||||||
# add minimum required paths to the "User" role
|
# add minimum required paths to the "User" role
|
||||||
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"):
|
for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
|
||||||
stmt = permissions_table.select().where(permissions_table.c.path == path)
|
stmt = permissions_table.select().where(permissions_table.c.path == path)
|
||||||
result = connection.execute(stmt)
|
result = connection.execute(stmt)
|
||||||
permission_id = result.first().permission_id
|
permission_id = result.first().permission_id
|
||||||
|
Loading…
Reference in New Issue
Block a user