Secure websocket endpoints

pull/1992/head
grossmj 3 years ago
parent 7ce5e19a6e
commit 741fc4a557

@ -16,8 +16,9 @@
import re
from fastapi import Request, Depends, HTTPException, status
from fastapi import Request, Query, Depends, HTTPException, WebSocket, status
from fastapi.security import OAuth2PasswordBearer
from typing import Optional
from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository
@ -76,3 +77,53 @@ async def get_current_active_user(
)
return current_user
async def get_current_active_user_from_websocket(
websocket: WebSocket,
token: str = Query(...),
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> Optional[schemas.User]:
await websocket.accept()
try:
username = auth_service.get_username_from_token(token)
user = await user_repo.get_user_by_username(username)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials for '{username}'"
)
# Super admin is always authorized
if user.is_superadmin:
return user
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"'{username}' is not an active user"
)
# remove the prefix (e.g. "/v3") from URL path
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
# there are no HTTP methods for web sockets, assuming "GET"...
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{user.user_id}' on '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return user
except HTTPException as e:
websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to "
f"WebSocket: {e.detail}"}}
await websocket.send_json(websocket_error)
await websocket.close(code=1008)

@ -15,13 +15,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Callable, Type
from fastapi import Depends, Request
from fastapi import Depends
from starlette.requests import HTTPConnection
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.base import BaseRepository
async def get_db_session(request: Request) -> AsyncSession:
async def get_db_session(request: HTTPConnection) -> AsyncSession:
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
try:

@ -18,14 +18,14 @@
API routes for controller notifications.
"""
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from websockets.exceptions import ConnectionClosed, WebSocketException
from gns3server.services import auth_service
from gns3server.controller import Controller
from gns3server import schemas
from .dependencies.authentication import get_current_active_user
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
import logging
@ -35,7 +35,7 @@ router = APIRouter()
@router.get("", dependencies=[Depends(get_current_active_user)])
async def http_notification() -> StreamingResponse:
async def controller_http_notifications() -> StreamingResponse:
"""
Receive controller notifications about the controller from HTTP stream.
"""
@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse:
@router.websocket("/ws")
async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None:
async def controller_ws_notifications(
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
"""
Receive project notifications about the controller from WebSocket.
"""
await websocket.accept()
if token:
try:
username = auth_service.get_username_from_token(token)
except HTTPException:
log.error("Invalid token received")
await websocket.close(code=1008)
return
if current_user is None:
return
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
try:

@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository
from gns3server.db.repositories.templates import TemplatesRepository
from gns3server.services.templates import TemplatesService
from .dependencies.authentication import get_current_active_user
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
@router.get("/{project_id}/notifications")
async def notification(project_id: UUID) -> StreamingResponse:
async def project_http_notifications(project_id: UUID) -> StreamingResponse:
"""
Receive project notifications about the controller from HTTP stream.
"""
@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse:
@router.websocket("/{project_id}/notifications/ws")
async def notification_ws(project_id: UUID, websocket: WebSocket) -> None:
async def project_ws_notifications(
project_id: UUID,
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
"""
Receive project notifications about the controller from WebSocket.
"""
if current_user is None:
return
controller = Controller.instance()
project = controller.get_project(str(project_id))
await websocket.accept()
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
try:

@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw):
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow to receive controller notifications",
"methods": ["GET"],
"path": "/notifications",
"action": "ALLOW"
},
{
"description": "Allow to create and list projects",
"methods": ["GET", "POST"],
@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw):
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"):
for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id

Loading…
Cancel
Save