1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-01-26 16:01:23 +00:00

Secure websocket endpoints

This commit is contained in:
grossmj 2021-11-01 16:45:14 +10:30
parent 7ce5e19a6e
commit 741fc4a557
5 changed files with 82 additions and 21 deletions

View File

@ -16,8 +16,9 @@
import re import re
from fastapi import Request, Depends, HTTPException, status from fastapi import Request, Query, Depends, HTTPException, WebSocket, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from typing import Optional
from gns3server import schemas from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
@ -76,3 +77,53 @@ async def get_current_active_user(
) )
return current_user return current_user
async def get_current_active_user_from_websocket(
websocket: WebSocket,
token: str = Query(...),
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> Optional[schemas.User]:
await websocket.accept()
try:
username = auth_service.get_username_from_token(token)
user = await user_repo.get_user_by_username(username)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials for '{username}'"
)
# Super admin is always authorized
if user.is_superadmin:
return user
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"'{username}' is not an active user"
)
# remove the prefix (e.g. "/v3") from URL path
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
# there are no HTTP methods for web sockets, assuming "GET"...
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{user.user_id}' on '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return user
except HTTPException as e:
websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to "
f"WebSocket: {e.detail}"}}
await websocket.send_json(websocket_error)
await websocket.close(code=1008)

View File

@ -15,13 +15,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Callable, Type from typing import Callable, Type
from fastapi import Depends, Request from fastapi import Depends
from starlette.requests import HTTPConnection
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.base import BaseRepository from gns3server.db.repositories.base import BaseRepository
async def get_db_session(request: Request) -> AsyncSession: async def get_db_session(request: HTTPConnection) -> AsyncSession:
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False) session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
try: try:

View File

@ -18,14 +18,14 @@
API routes for controller notifications. API routes for controller notifications.
""" """
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from websockets.exceptions import ConnectionClosed, WebSocketException from websockets.exceptions import ConnectionClosed, WebSocketException
from gns3server.services import auth_service
from gns3server.controller import Controller from gns3server.controller import Controller
from gns3server import schemas
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
import logging import logging
@ -35,7 +35,7 @@ router = APIRouter()
@router.get("", dependencies=[Depends(get_current_active_user)]) @router.get("", dependencies=[Depends(get_current_active_user)])
async def http_notification() -> StreamingResponse: async def controller_http_notifications() -> StreamingResponse:
""" """
Receive controller notifications about the controller from HTTP stream. Receive controller notifications about the controller from HTTP stream.
""" """
@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse:
@router.websocket("/ws") @router.websocket("/ws")
async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None: async def controller_ws_notifications(
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
""" """
Receive project notifications about the controller from WebSocket. Receive project notifications about the controller from WebSocket.
""" """
await websocket.accept()
if token: if current_user is None:
try: return
username = auth_service.get_username_from_token(token)
except HTTPException:
log.error("Invalid token received")
await websocket.close(code=1008)
return
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket") log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
try: try:

View File

@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository
from gns3server.db.repositories.templates import TemplatesRepository from gns3server.db.repositories.templates import TemplatesRepository
from gns3server.services.templates import TemplatesService from gns3server.services.templates import TemplatesService
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
from .dependencies.database import get_repository from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
@router.get("/{project_id}/notifications") @router.get("/{project_id}/notifications")
async def notification(project_id: UUID) -> StreamingResponse: async def project_http_notifications(project_id: UUID) -> StreamingResponse:
""" """
Receive project notifications about the controller from HTTP stream. Receive project notifications about the controller from HTTP stream.
""" """
@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse:
@router.websocket("/{project_id}/notifications/ws") @router.websocket("/{project_id}/notifications/ws")
async def notification_ws(project_id: UUID, websocket: WebSocket) -> None: async def project_ws_notifications(
project_id: UUID,
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
""" """
Receive project notifications about the controller from WebSocket. Receive project notifications about the controller from WebSocket.
""" """
if current_user is None:
return
controller = Controller.instance() controller = Controller.instance()
project = controller.get_project(str(project_id)) project = controller.get_project(str(project_id))
await websocket.accept()
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)") log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
try: try:

View File

@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw):
"path": "/", "path": "/",
"action": "ALLOW" "action": "ALLOW"
}, },
{
"description": "Allow to receive controller notifications",
"methods": ["GET"],
"path": "/notifications",
"action": "ALLOW"
},
{ {
"description": "Allow to create and list projects", "description": "Allow to create and list projects",
"methods": ["GET", "POST"], "methods": ["GET", "POST"],
@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw):
role_id = result.first().role_id role_id = result.first().role_id
# add minimum required paths to the "User" role # add minimum required paths to the "User" role
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"): for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
stmt = permissions_table.select().where(permissions_table.c.path == path) stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt) result = connection.execute(stmt)
permission_id = result.first().permission_id permission_id = result.first().permission_id