From 741fc4a557859e4ad5c47ccf00876051880eb1f0 Mon Sep 17 00:00:00 2001 From: grossmj Date: Mon, 1 Nov 2021 16:45:14 +1030 Subject: [PATCH] Secure websocket endpoints --- .../controller/dependencies/authentication.py | 53 ++++++++++++++++++- .../controller/dependencies/database.py | 5 +- .../api/routes/controller/notifications.py | 23 ++++---- gns3server/api/routes/controller/projects.py | 14 +++-- gns3server/db/models/permissions.py | 8 ++- 5 files changed, 82 insertions(+), 21 deletions(-) diff --git a/gns3server/api/routes/controller/dependencies/authentication.py b/gns3server/api/routes/controller/dependencies/authentication.py index 0ca08e21..c1647c4b 100644 --- a/gns3server/api/routes/controller/dependencies/authentication.py +++ b/gns3server/api/routes/controller/dependencies/authentication.py @@ -16,8 +16,9 @@ import re -from fastapi import Request, Depends, HTTPException, status +from fastapi import Request, Query, Depends, HTTPException, WebSocket, status from fastapi.security import OAuth2PasswordBearer +from typing import Optional from gns3server import schemas from gns3server.db.repositories.users import UsersRepository @@ -76,3 +77,53 @@ async def get_current_active_user( ) return current_user + + +async def get_current_active_user_from_websocket( + websocket: WebSocket, + token: str = Query(...), + user_repo: UsersRepository = Depends(get_repository(UsersRepository)), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> Optional[schemas.User]: + + await websocket.accept() + + try: + username = auth_service.get_username_from_token(token) + user = await user_repo.get_user_by_username(username) + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Could not validate credentials for '{username}'" + ) + + # Super admin is always authorized + if user.is_superadmin: + return user + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"'{username}' is not an active user" + ) + + # remove the prefix (e.g. "/v3") from URL path + path = re.sub(r"^/v[0-9]", "", websocket.url.path) + + # there are no HTTP methods for web sockets, assuming "GET"... + authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path) + if not authorized: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"User is not authorized '{user.user_id}' on '{path}'", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return user + + except HTTPException as e: + websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to " + f"WebSocket: {e.detail}"}} + await websocket.send_json(websocket_error) + await websocket.close(code=1008) diff --git a/gns3server/api/routes/controller/dependencies/database.py b/gns3server/api/routes/controller/dependencies/database.py index b003cd02..fe2226c7 100644 --- a/gns3server/api/routes/controller/dependencies/database.py +++ b/gns3server/api/routes/controller/dependencies/database.py @@ -15,13 +15,14 @@ # along with this program. If not, see . from typing import Callable, Type -from fastapi import Depends, Request +from fastapi import Depends +from starlette.requests import HTTPConnection from sqlalchemy.ext.asyncio import AsyncSession from gns3server.db.repositories.base import BaseRepository -async def get_db_session(request: Request) -> AsyncSession: +async def get_db_session(request: HTTPConnection) -> AsyncSession: session = AsyncSession(request.app.state._db_engine, expire_on_commit=False) try: diff --git a/gns3server/api/routes/controller/notifications.py b/gns3server/api/routes/controller/notifications.py index 79e99328..624a2f80 100644 --- a/gns3server/api/routes/controller/notifications.py +++ b/gns3server/api/routes/controller/notifications.py @@ -18,14 +18,14 @@ API routes for controller notifications. """ -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException +from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from fastapi.responses import StreamingResponse from websockets.exceptions import ConnectionClosed, WebSocketException -from gns3server.services import auth_service from gns3server.controller import Controller +from gns3server import schemas -from .dependencies.authentication import get_current_active_user +from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket import logging @@ -35,7 +35,7 @@ router = APIRouter() @router.get("", dependencies=[Depends(get_current_active_user)]) -async def http_notification() -> StreamingResponse: +async def controller_http_notifications() -> StreamingResponse: """ Receive controller notifications about the controller from HTTP stream. """ @@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse: @router.websocket("/ws") -async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None: +async def controller_ws_notifications( + websocket: WebSocket, + current_user: schemas.User = Depends(get_current_active_user_from_websocket) +) -> None: """ Receive project notifications about the controller from WebSocket. """ - await websocket.accept() - if token: - try: - username = auth_service.get_username_from_token(token) - except HTTPException: - log.error("Invalid token received") - await websocket.close(code=1008) - return + if current_user is None: + return log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket") try: diff --git a/gns3server/api/routes/controller/projects.py b/gns3server/api/routes/controller/projects.py index 757faf6f..4050b24b 100644 --- a/gns3server/api/routes/controller/projects.py +++ b/gns3server/api/routes/controller/projects.py @@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository from gns3server.db.repositories.templates import TemplatesRepository from gns3server.services.templates import TemplatesService -from .dependencies.authentication import get_current_active_user +from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket from .dependencies.database import get_repository responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} @@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project: @router.get("/{project_id}/notifications") -async def notification(project_id: UUID) -> StreamingResponse: +async def project_http_notifications(project_id: UUID) -> StreamingResponse: """ Receive project notifications about the controller from HTTP stream. """ @@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse: @router.websocket("/{project_id}/notifications/ws") -async def notification_ws(project_id: UUID, websocket: WebSocket) -> None: +async def project_ws_notifications( + project_id: UUID, + websocket: WebSocket, + current_user: schemas.User = Depends(get_current_active_user_from_websocket) +) -> None: """ Receive project notifications about the controller from WebSocket. """ + if current_user is None: + return + controller = Controller.instance() project = controller.get_project(str(project_id)) - await websocket.accept() log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)") try: diff --git a/gns3server/db/models/permissions.py b/gns3server/db/models/permissions.py index 8be3d669..f7344e31 100644 --- a/gns3server/db/models/permissions.py +++ b/gns3server/db/models/permissions.py @@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw): "path": "/", "action": "ALLOW" }, + { + "description": "Allow to receive controller notifications", + "methods": ["GET"], + "path": "/notifications", + "action": "ALLOW" + }, { "description": "Allow to create and list projects", "methods": ["GET", "POST"], @@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw): role_id = result.first().role_id # add minimum required paths to the "User" role - for path in ("/projects", "/templates", "/computes/*", "/symbols/*"): + for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"): stmt = permissions_table.select().where(permissions_table.c.path == path) result = connection.execute(stmt) permission_id = result.first().permission_id