diff --git a/gns3server/api/routes/controller/__init__.py b/gns3server/api/routes/controller/__init__.py index 9a2dc526..2e12010e 100644 --- a/gns3server/api/routes/controller/__init__.py +++ b/gns3server/api/routes/controller/__init__.py @@ -30,6 +30,8 @@ from . import symbols from . import templates from . import users from . import groups +from . import roles +from . import permissions from .dependencies.authentication import get_current_active_user @@ -46,31 +48,36 @@ router.include_router( ) router.include_router( - appliances.router, + roles.router, dependencies=[Depends(get_current_active_user)], - prefix="/appliances", - tags=["Appliances"] + prefix="/roles", + tags=["Roles"] ) router.include_router( - computes.router, + permissions.router, dependencies=[Depends(get_current_active_user)], - prefix="/computes", - tags=["Computes"] + prefix="/permissions", + tags=["Permissions"] ) router.include_router( - drawings.router, + templates.router, dependencies=[Depends(get_current_active_user)], - prefix="/projects/{project_id}/drawings", - tags=["Drawings"]) + tags=["Templates"] +) router.include_router( - gns3vm.router, - deprecated=True, + projects.router, dependencies=[Depends(get_current_active_user)], - prefix="/gns3vm", - tags=["GNS3 VM"] + prefix="/projects", + tags=["Projects"]) + +router.include_router( + nodes.router, + dependencies=[Depends(get_current_active_user)], + prefix="/projects/{project_id}/nodes", + tags=["Nodes"] ) router.include_router( @@ -81,10 +88,28 @@ router.include_router( ) router.include_router( - nodes.router, + drawings.router, dependencies=[Depends(get_current_active_user)], - prefix="/projects/{project_id}/nodes", - tags=["Nodes"] + prefix="/projects/{project_id}/drawings", + tags=["Drawings"]) + +router.include_router( + symbols.router, + dependencies=[Depends(get_current_active_user)], + prefix="/symbols", tags=["Symbols"] +) + +router.include_router( + snapshots.router, + dependencies=[Depends(get_current_active_user)], + prefix="/projects/{project_id}/snapshots", + tags=["Snapshots"]) + +router.include_router( + computes.router, + dependencies=[Depends(get_current_active_user)], + prefix="/computes", + tags=["Computes"] ) router.include_router( @@ -94,25 +119,16 @@ router.include_router( tags=["Notifications"]) router.include_router( - projects.router, + appliances.router, dependencies=[Depends(get_current_active_user)], - prefix="/projects", - tags=["Projects"]) - -router.include_router( - snapshots.router, - dependencies=[Depends(get_current_active_user)], - prefix="/projects/{project_id}/snapshots", - tags=["Snapshots"]) - -router.include_router( - symbols.router, - dependencies=[Depends(get_current_active_user)], - prefix="/symbols", tags=["Symbols"] + prefix="/appliances", + tags=["Appliances"] ) router.include_router( - templates.router, + gns3vm.router, + deprecated=True, dependencies=[Depends(get_current_active_user)], - tags=["Templates"] + prefix="/gns3vm", + tags=["GNS3 VM"] ) diff --git a/gns3server/api/routes/controller/dependencies/authentication.py b/gns3server/api/routes/controller/dependencies/authentication.py index 60ee5de2..0af058d7 100644 --- a/gns3server/api/routes/controller/dependencies/authentication.py +++ b/gns3server/api/routes/controller/dependencies/authentication.py @@ -14,14 +14,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import re -from fastapi import Depends, HTTPException, status +from fastapi import Request, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from gns3server import schemas from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.rbac import RbacRepository from gns3server.services import auth_service - from .database import get_repository oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") @@ -42,7 +43,11 @@ async def get_user_from_token( return user -async def get_current_active_user(current_user: schemas.User = Depends(get_user_from_token)) -> schemas.User: +async def get_current_active_user( + request: Request, + current_user: schemas.User = Depends(get_user_from_token), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.User: # Super admin is always authorized if current_user.is_superadmin: @@ -54,4 +59,24 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_user_ detail="Not an active user", headers={"WWW-Authenticate": "Bearer"}, ) + + # remove the prefix (e.g. "/v3") from URL path + match = re.search(r"^(/v[0-9]+).*", request.url.path) + if match: + path = request.url.path[len(match.group(1)):] + else: + path = request.url.path + + # special case: always authorize access to the "/users/me" endpoint + if path == "/users/me": + return current_user + + authorized = await rbac_repo.check_user_is_authorized(current_user.user_id, request.method, path) + if not authorized: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"User is not authorized '{current_user.user_id}' on {request.method} '{path}'", + headers={"WWW-Authenticate": "Bearer"}, + ) + return current_user diff --git a/gns3server/api/routes/controller/groups.py b/gns3server/api/routes/controller/groups.py index b3a96f66..20b17ae4 100644 --- a/gns3server/api/routes/controller/groups.py +++ b/gns3server/api/routes/controller/groups.py @@ -31,6 +31,7 @@ from gns3server.controller.controller_error import ( ) from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.rbac import RbacRepository from .dependencies.database import get_repository import logging @@ -98,8 +99,8 @@ async def update_user_group( if not user_group: raise ControllerNotFoundError(f"User group '{user_group_id}' not found") - if not user_group.is_updatable: - raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated") + if user_group.builtin: + raise ControllerForbiddenError(f"Built-in user group '{user_group_id}' cannot be updated") return await users_repo.update_user_group(user_group_id, user_group_update) @@ -120,8 +121,8 @@ async def delete_user_group( if not user_group: raise ControllerNotFoundError(f"User group '{user_group_id}' not found") - if not user_group.is_updatable: - raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted") + if user_group.builtin: + raise ControllerForbiddenError(f"Built-in user group '{user_group_id}' cannot be deleted") success = await users_repo.delete_user_group(user_group_id) if not success: @@ -182,3 +183,61 @@ async def remove_member_from_group( user_group = await users_repo.remove_member_from_user_group(user_group_id, user) if not user_group: raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + + +@router.get("/{user_group_id}/roles", response_model=List[schemas.Role]) +async def get_user_group_roles( + user_group_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> List[schemas.Role]: + """ + Get all user group roles. + """ + + return await users_repo.get_user_group_roles(user_group_id) + + +@router.put( + "/{user_group_id}/roles/{role_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def add_role_to_group( + user_group_id: UUID, + role_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> None: + """ + Add role to an user group. + """ + + role = await rbac_repo.get_role(role_id) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + + user_group = await users_repo.add_role_to_user_group(user_group_id, role) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + + +@router.delete( + "/{user_group_id}/roles/{role_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def remove_role_from_group( + user_group_id: UUID, + role_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> None: + """ + Remove role from an user group. + """ + + role = await rbac_repo.get_role(role_id) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + + user_group = await users_repo.remove_role_from_user_group(user_group_id, role) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") diff --git a/gns3server/api/routes/controller/permissions.py b/gns3server/api/routes/controller/permissions.py new file mode 100644 index 00000000..466a2707 --- /dev/null +++ b/gns3server/api/routes/controller/permissions.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" +API routes for permissions. +""" + +from fastapi import APIRouter, Depends, status +from uuid import UUID +from typing import List + +from gns3server import schemas +from gns3server.controller.controller_error import ( + ControllerBadRequestError, + ControllerNotFoundError, + ControllerForbiddenError, +) + +from gns3server.db.repositories.rbac import RbacRepository +from .dependencies.database import get_repository + +import logging + +log = logging.getLogger(__name__) + +router = APIRouter() + + +@router.get("", response_model=List[schemas.Permission]) +async def get_permissions( + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> List[schemas.Permission]: + """ + Get all permissions. + """ + + return await rbac_repo.get_permissions() + + +@router.post("", response_model=schemas.Permission, status_code=status.HTTP_201_CREATED) +async def create_permission( + permission_create: schemas.PermissionCreate, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.Permission: + """ + Create a new permission. + """ + + if await rbac_repo.check_permission_exists(permission_create): + raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path} " + f"{permission_create.action}' already exists") + + return await rbac_repo.create_permission(permission_create) + + +@router.get("/{permission_id}", response_model=schemas.Permission) +async def get_permission( + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> schemas.Permission: + """ + Get a permission. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + return permission + + +@router.put("/{permission_id}", response_model=schemas.Permission) +async def update_permission( + permission_id: UUID, + permission_update: schemas.PermissionUpdate, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.Permission: + """ + Update a permission. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + return await rbac_repo.update_permission(permission_id, permission_update) + + +@router.delete("/{permission_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_permission( + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> None: + """ + Delete a permission. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + success = await rbac_repo.delete_permission(permission_id) + if not success: + raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted") diff --git a/gns3server/api/routes/controller/projects.py b/gns3server/api/routes/controller/projects.py index 5e8022c8..55252c41 100644 --- a/gns3server/api/routes/controller/projects.py +++ b/gns3server/api/routes/controller/projects.py @@ -47,6 +47,10 @@ from gns3server.controller.export_project import export_project as export_contro from gns3server.utils.asyncio import aiozipstream from gns3server.utils.path import is_safe_path from gns3server.config import Config +from gns3server.db.repositories.rbac import RbacRepository + +from .dependencies.authentication import get_current_active_user +from .dependencies.database import get_repository responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} @@ -66,13 +70,25 @@ CHUNK_SIZE = 1024 * 8 # 8KB @router.get("", response_model=List[schemas.Project], response_model_exclude_unset=True) -def get_projects() -> List[schemas.Project]: +async def get_projects( + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> List[schemas.Project]: """ Return all projects. """ controller = Controller.instance() - return [p.asdict() for p in controller.projects.values()] + if current_user.is_superadmin: + return [p.asdict() for p in controller.projects.values()] + else: + user_projects = [] + for project in controller.projects.values(): + authorized = await rbac_repo.check_user_is_authorized( + current_user.user_id, "GET", f"/projects/{project.id}") + if authorized: + user_projects.append(project.asdict()) + return user_projects @router.post( @@ -82,13 +98,18 @@ def get_projects() -> List[schemas.Project]: response_model_exclude_unset=True, responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}}, ) -async def create_project(project_data: schemas.ProjectCreate) -> schemas.Project: +async def create_project( + project_data: schemas.ProjectCreate, + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.Project: """ Create a new project. """ controller = Controller.instance() project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True)) + await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{project.id}/*") return project.asdict() @@ -115,7 +136,10 @@ async def update_project( @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_project(project: Project = Depends(dep_project)) -> None: +async def delete_project( + project: Project = Depends(dep_project), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> None: """ Delete a project. """ @@ -123,6 +147,7 @@ async def delete_project(project: Project = Depends(dep_project)) -> None: controller = Controller.instance() await project.delete() controller.remove_project(project) + await rbac_repo.delete_all_permissions_with_path(f"/projects/{project.id}") @router.get("/{project_id}/stats") @@ -344,7 +369,9 @@ async def import_project( ) async def duplicate_project( project_data: schemas.ProjectDuplicate, - project: Project = Depends(dep_project) + project: Project = Depends(dep_project), + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) ) -> schemas.Project: """ Duplicate a project. @@ -361,6 +388,7 @@ async def duplicate_project( new_project = await project.duplicate( name=project_data.name, location=location, reset_mac_addresses=reset_mac_addresses ) + await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{new_project.id}/*") return new_project.asdict() diff --git a/gns3server/api/routes/controller/roles.py b/gns3server/api/routes/controller/roles.py new file mode 100644 index 00000000..c96feb64 --- /dev/null +++ b/gns3server/api/routes/controller/roles.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" +API routes for roles. +""" + +from fastapi import APIRouter, Depends, status +from uuid import UUID +from typing import List + +from gns3server import schemas +from gns3server.controller.controller_error import ( + ControllerBadRequestError, + ControllerNotFoundError, + ControllerForbiddenError, +) + +from gns3server.db.repositories.rbac import RbacRepository +from .dependencies.database import get_repository + +import logging + +log = logging.getLogger(__name__) + +router = APIRouter() + + +@router.get("", response_model=List[schemas.Role]) +async def get_roles( + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> List[schemas.Role]: + """ + Get all roles. + """ + + return await rbac_repo.get_roles() + + +@router.post("", response_model=schemas.Role, status_code=status.HTTP_201_CREATED) +async def create_role( + role_create: schemas.RoleCreate, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.Role: + """ + Create a new role. + """ + + if await rbac_repo.get_role_by_name(role_create.name): + raise ControllerBadRequestError(f"Role '{role_create.name}' already exists") + + return await rbac_repo.create_role(role_create) + + +@router.get("/{role_id}", response_model=schemas.Role) +async def get_role( + role_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> schemas.Role: + """ + Get a role. + """ + + role = await rbac_repo.get_role(role_id) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + return role + + +@router.put("/{role_id}", response_model=schemas.Role) +async def update_role( + role_id: UUID, + role_update: schemas.RoleUpdate, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> schemas.Role: + """ + Update a role. + """ + + role = await rbac_repo.get_role(role_id) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + + if role.builtin: + raise ControllerForbiddenError(f"Built-in role '{role_id}' cannot be updated") + + return await rbac_repo.update_role(role_id, role_update) + + +@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_role( + role_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> None: + """ + Delete a role. + """ + + role = await rbac_repo.get_role(role_id) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + + if role.builtin: + raise ControllerForbiddenError(f"Built-in role '{role_id}' cannot be deleted") + + success = await rbac_repo.delete_role(role_id) + if not success: + raise ControllerNotFoundError(f"Role '{role_id}' could not be deleted") + + +@router.get("/{role_id}/permissions", response_model=List[schemas.Permission]) +async def get_role_permissions( + role_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> List[schemas.Permission]: + """ + Get all role permissions. + """ + + return await rbac_repo.get_role_permissions(role_id) + + +@router.put( + "/{role_id}/permissions/{permission_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def add_permission_to_role( + role_id: UUID, + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> None: + """ + Add a permission to a role. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + role = await rbac_repo.add_permission_to_role(role_id, permission) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") + + +@router.delete( + "/{role_id}/permissions/{permission_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def remove_permission_from_role( + role_id: UUID, + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> None: + """ + Remove member from an user group. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + role = await rbac_repo.remove_permission_from_role(role_id, permission) + if not role: + raise ControllerNotFoundError(f"Role '{role_id}' not found") diff --git a/gns3server/api/routes/controller/templates.py b/gns3server/api/routes/controller/templates.py index ee54fdc7..a346545a 100644 --- a/gns3server/api/routes/controller/templates.py +++ b/gns3server/api/routes/controller/templates.py @@ -33,6 +33,9 @@ from gns3server import schemas from gns3server.controller import Controller from gns3server.db.repositories.templates import TemplatesRepository from gns3server.services.templates import TemplatesService +from gns3server.db.repositories.rbac import RbacRepository + +from .dependencies.authentication import get_current_active_user from .dependencies.database import get_repository responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find template"}} @@ -44,12 +47,17 @@ router = APIRouter(responses=responses) async def create_template( template_create: schemas.TemplateCreate, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) ) -> schemas.Template: """ Create a new template. """ - return await TemplatesService(templates_repo).create_template(template_create) + template = await TemplatesService(templates_repo).create_template(template_create) + template_id = template.get("template_id") + await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*") + return template @router.get("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) @@ -92,35 +100,58 @@ async def update_template( status_code=status.HTTP_204_NO_CONTENT, ) async def delete_template( - template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + template_id: UUID, + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) ) -> None: """ Delete a template. """ await TemplatesService(templates_repo).delete_template(template_id) + await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}") @router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True) async def get_templates( - templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), + templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) ) -> List[schemas.Template]: """ Return all templates. """ - return await TemplatesService(templates_repo).get_templates() + templates = await TemplatesService(templates_repo).get_templates() + if current_user.is_superadmin: + return templates + else: + user_templates = [] + for template in templates: + if template.get("builtin") is True: + user_templates.append(template) + continue + template_id = template.get("template_id") + authorized = await rbac_repo.check_user_is_authorized( + current_user.user_id, "GET", f"/templates/{template_id}") + if authorized: + user_templates.append(template) + return user_templates @router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) async def duplicate_template( - template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) + template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), + current_user: schemas.User = Depends(get_current_active_user), + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) ) -> schemas.Template: """ Duplicate a template. """ - return await TemplatesService(templates_repo).duplicate_template(template_id) + template = await TemplatesService(templates_repo).duplicate_template(template_id) + await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*") + return template @router.post( diff --git a/gns3server/api/routes/controller/users.py b/gns3server/api/routes/controller/users.py index e3fc578e..edf53b9a 100644 --- a/gns3server/api/routes/controller/users.py +++ b/gns3server/api/routes/controller/users.py @@ -32,6 +32,7 @@ from gns3server.controller.controller_error import ( ) from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.rbac import RbacRepository from gns3server.services import auth_service from .dependencies.authentication import get_current_active_user @@ -88,6 +89,24 @@ async def authenticate( return token +@router.get("/me", response_model=schemas.User) +async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User: + """ + Get the current active user. + """ + + return current_user + + +@router.get("/me", response_model=schemas.User) +async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User: + """ + Get the current active user. + """ + + return current_user + + @router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)]) async def get_users( users_repo: UsersRepository = Depends(get_repository(UsersRepository)) @@ -178,15 +197,6 @@ async def delete_user( raise ControllerNotFoundError(f"User '{user_id}' could not be deleted") -@router.get("/me/", response_model=schemas.User) -async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User: - """ - Get the current active user. - """ - - return current_user - - @router.get( "/{user_id}/groups", dependencies=[Depends(get_current_active_user)], @@ -201,3 +211,65 @@ async def get_user_memberships( """ return await users_repo.get_user_memberships(user_id) + + +@router.get( + "/{user_id}/permissions", + dependencies=[Depends(get_current_active_user)], + response_model=List[schemas.Permission] +) +async def get_user_permissions( + user_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> List[schemas.Permission]: + """ + Get user permissions. + """ + + return await rbac_repo.get_user_permissions(user_id) + + +@router.put( + "/{user_id}/permissions/{permission_id}", + dependencies=[Depends(get_current_active_user)], + status_code=status.HTTP_204_NO_CONTENT +) +async def add_permission_to_user( + user_id: UUID, + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) +) -> None: + """ + Add a permission to an user. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + user = await rbac_repo.add_permission_to_user(user_id, permission) + if not user: + raise ControllerNotFoundError(f"User '{user_id}' not found") + + +@router.delete( + "/{user_id}/permissions/{permission_id}", + dependencies=[Depends(get_current_active_user)], + status_code=status.HTTP_204_NO_CONTENT +) +async def remove_permission_from_user( + user_id: UUID, + permission_id: UUID, + rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), +) -> None: + """ + Remove permission from an user. + """ + + permission = await rbac_repo.get_permission(permission_id) + if not permission: + raise ControllerNotFoundError(f"Permission '{permission_id}' not found") + + user = await rbac_repo.remove_permission_from_user(user_id, permission) + if not user: + raise ControllerNotFoundError(f"User '{user_id}' not found") diff --git a/gns3server/db/models/__init__.py b/gns3server/db/models/__init__.py index 1c644f72..ed5f7ead 100644 --- a/gns3server/db/models/__init__.py +++ b/gns3server/db/models/__init__.py @@ -17,6 +17,8 @@ from .base import Base from .users import User, UserGroup +from .roles import Role +from .permissions import Permission from .computes import Compute from .templates import ( Template, diff --git a/gns3server/db/models/base.py b/gns3server/db/models/base.py index 1dbae1af..731a4547 100644 --- a/gns3server/db/models/base.py +++ b/gns3server/db/models/base.py @@ -19,7 +19,7 @@ import uuid from fastapi.encoders import jsonable_encoder from sqlalchemy import Column, DateTime, func, inspect -from sqlalchemy.types import TypeDecorator, CHAR +from sqlalchemy.types import TypeDecorator, CHAR, VARCHAR from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.declarative import as_declarative @@ -72,6 +72,37 @@ class GUID(TypeDecorator): return value +class ListException(Exception): + pass + + +class ListType(TypeDecorator): + """ + Save/restore a Python list to/from a database column. + """ + + impl = VARCHAR + cache_ok = True + + def __init__(self, separator=',', *args, **kwargs): + + self._separator = separator + super().__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + if value is not None: + if any(self._separator in str(item) for item in value): + raise ListException(f"List values cannot contain '{self._separator}'" + f"Please use a different separator.") + return self._separator.join(map(str, value)) + + def process_result_value(self, value, dialect): + if value is None: + return [] + else: + return list(map(str, value.split(self._separator))) + + class BaseTable(Base): __abstract__ = True @@ -79,6 +110,8 @@ class BaseTable(Base): created_at = Column(DateTime, server_default=func.current_timestamp()) updated_at = Column(DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + __mapper_args__ = {"eager_defaults": True} + def generate_uuid(): return str(uuid.uuid4()) diff --git a/gns3server/db/models/permissions.py b/gns3server/db/models/permissions.py new file mode 100644 index 00000000..4779b6af --- /dev/null +++ b/gns3server/db/models/permissions.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from sqlalchemy import Table, Column, String, ForeignKey, event +from sqlalchemy.orm import relationship + +from .base import Base, BaseTable, generate_uuid, GUID, ListType + +import logging + +log = logging.getLogger(__name__) + + +permission_role_link = Table( + "permissions_roles_link", + Base.metadata, + Column("permission_id", GUID, ForeignKey("permissions.permission_id", ondelete="CASCADE")), + Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE")) + +) + + +class Permission(BaseTable): + + __tablename__ = "permissions" + + permission_id = Column(GUID, primary_key=True, default=generate_uuid) + description = Column(String) + methods = Column(ListType) + path = Column(String) + action = Column(String) + user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE")) + roles = relationship("Role", secondary=permission_role_link, back_populates="permissions") + + +@event.listens_for(Permission.__table__, 'after_create') +def create_default_roles(target, connection, **kw): + + default_permissions = [ + { + "description": "Allow access to all endpoints", + "methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"], + "path": "/", + "action": "ALLOW" + }, + { + "description": "Allow to create and list projects", + "methods": ["GET", "HEAD", "POST"], + "path": "/projects", + "action": "ALLOW" + }, + { + "description": "Allow to create and list templates", + "methods": ["GET", "HEAD", "POST"], + "path": "/templates", + "action": "ALLOW" + }, + { + "description": "Allow to list computes", + "methods": ["GET"], + "path": "/computes/*", + "action": "ALLOW" + }, + { + "description": "Allow access to all symbol endpoints", + "methods": ["GET", "HEAD", "POST"], + "path": "/symbols/*", + "action": "ALLOW" + }, + ] + + stmt = target.insert().values(default_permissions) + connection.execute(stmt) + connection.commit() + log.debug("The default permissions have been created in the database") + + +@event.listens_for(permission_role_link, 'after_create') +def add_permissions_to_role(target, connection, **kw): + + from .roles import Role + roles_table = Role.__table__ + stmt = roles_table.select().where(roles_table.c.name == "Administrator") + result = connection.execute(stmt) + role_id = result.first().role_id + + permissions_table = Permission.__table__ + stmt = permissions_table.select().where(permissions_table.c.path == "/") + result = connection.execute(stmt) + permission_id = result.first().permission_id + + # add root path to the "Administrator" role + stmt = target.insert().values(permission_id=permission_id, role_id=role_id) + connection.execute(stmt) + + stmt = roles_table.select().where(roles_table.c.name == "User") + result = connection.execute(stmt) + role_id = result.first().role_id + + # add minimum required paths to the "User" role + for path in ("/projects", "/templates", "/computes/*", "/symbols/*"): + stmt = permissions_table.select().where(permissions_table.c.path == path) + result = connection.execute(stmt) + permission_id = result.first().permission_id + stmt = target.insert().values(permission_id=permission_id, role_id=role_id) + connection.execute(stmt) + + connection.commit() diff --git a/gns3server/db/models/roles.py b/gns3server/db/models/roles.py new file mode 100644 index 00000000..76b1d6f1 --- /dev/null +++ b/gns3server/db/models/roles.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from sqlalchemy import Table, Column, String, Boolean, ForeignKey, event +from sqlalchemy.orm import relationship + +from .base import Base, BaseTable, generate_uuid, GUID +from .permissions import permission_role_link + +import logging + +log = logging.getLogger(__name__) + +role_group_link = Table( + "roles_groups_link", + Base.metadata, + Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE")), + Column("user_group_id", GUID, ForeignKey("user_groups.user_group_id", ondelete="CASCADE")) +) + + +class Role(BaseTable): + + __tablename__ = "roles" + + role_id = Column(GUID, primary_key=True, default=generate_uuid) + name = Column(String) + description = Column(String) + builtin = Column(Boolean, default=False) + permissions = relationship("Permission", secondary=permission_role_link, back_populates="roles") + groups = relationship("UserGroup", secondary=role_group_link, back_populates="roles") + + +@event.listens_for(Role.__table__, 'after_create') +def create_default_roles(target, connection, **kw): + + default_roles = [ + {"name": "Administrator", "description": "Administrator role", "builtin": True}, + {"name": "User", "description": "User role", "builtin": True}, + ] + + stmt = target.insert().values(default_roles) + connection.execute(stmt) + connection.commit() + log.debug("The default roles have been created in the database") + + +@event.listens_for(role_group_link, 'after_create') +def add_admin_to_group(target, connection, **kw): + + from .users import UserGroup + user_groups_table = UserGroup.__table__ + roles_table = Role.__table__ + + # Add roles to built-in user groups + groups_to_roles = {"Administrators": "Administrator", "Users": "User"} + for user_group, role in groups_to_roles.items(): + stmt = user_groups_table.select().where(user_groups_table.c.name == user_group) + result = connection.execute(stmt) + user_group_id = result.first().user_group_id + stmt = roles_table.select().where(roles_table.c.name == role) + result = connection.execute(stmt) + role_id = result.first().role_id + stmt = target.insert().values(role_id=role_id, user_group_id=user_group_id) + connection.execute(stmt) + + connection.commit() diff --git a/gns3server/db/models/users.py b/gns3server/db/models/users.py index 046ec5c4..32c44bbd 100644 --- a/gns3server/db/models/users.py +++ b/gns3server/db/models/users.py @@ -19,6 +19,8 @@ from sqlalchemy import Table, Boolean, Column, String, ForeignKey, event from sqlalchemy.orm import relationship from .base import Base, BaseTable, generate_uuid, GUID +from .roles import role_group_link + from gns3server.config import Config from gns3server.services import auth_service @@ -26,11 +28,11 @@ import logging log = logging.getLogger(__name__) -users_group_members = Table( - "users_group_members", +user_group_link = Table( + "users_groups_link", Base.metadata, Column("user_id", GUID, ForeignKey("users.user_id", ondelete="CASCADE")), - Column("user_group_id", GUID, ForeignKey("users_group.user_group_id", ondelete="CASCADE")) + Column("user_group_id", GUID, ForeignKey("user_groups.user_group_id", ondelete="CASCADE")) ) @@ -45,7 +47,8 @@ class User(BaseTable): hashed_password = Column(String) is_active = Column(Boolean, default=True) is_superadmin = Column(Boolean, default=False) - groups = relationship("UserGroup", secondary=users_group_members, back_populates="users") + groups = relationship("UserGroup", secondary=user_group_link, back_populates="users") + permissions = relationship("Permission") @event.listens_for(User.__table__, 'after_create') @@ -63,47 +66,47 @@ def create_default_super_admin(target, connection, **kw): ) connection.execute(stmt) connection.commit() - log.info("The default super admin account has been created in the database") + log.debug("The default super admin account has been created in the database") class UserGroup(BaseTable): - __tablename__ = "users_group" + __tablename__ = "user_groups" user_group_id = Column(GUID, primary_key=True, default=generate_uuid) name = Column(String, unique=True, index=True) - is_updatable = Column(Boolean, default=True) - users = relationship("User", secondary=users_group_members, back_populates="groups") + builtin = Column(Boolean, default=False) + users = relationship("User", secondary=user_group_link, back_populates="groups") + roles = relationship("Role", secondary=role_group_link, back_populates="groups") @event.listens_for(UserGroup.__table__, 'after_create') def create_default_user_groups(target, connection, **kw): default_groups = [ - {"name": "Administrators", "is_updatable": False}, - {"name": "Editors", "is_updatable": False}, - {"name": "Users", "is_updatable": False} + {"name": "Administrators", "builtin": True}, + {"name": "Users", "builtin": True} ] stmt = target.insert().values(default_groups) connection.execute(stmt) connection.commit() - log.info("The default user groups have been created in the database") + log.debug("The default user groups have been created in the database") -@event.listens_for(users_group_members, 'after_create') -def add_admin_to_group(target, connection, **kw): - - users_group_table = UserGroup.__table__ - stmt = users_group_table.select().where(users_group_table.c.name == "Administrators") - result = connection.execute(stmt) - user_group_id = result.first().user_group_id - - users_table = User.__table__ - stmt = users_table.select().where(users_table.c.is_superadmin.is_(True)) - result = connection.execute(stmt) - user_id = result.first().user_id - - stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id) - connection.execute(stmt) - connection.commit() +# @event.listens_for(user_group_link, 'after_create') +# def add_admin_to_group(target, connection, **kw): +# +# user_groups_table = UserGroup.__table__ +# stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators") +# result = connection.execute(stmt) +# user_group_id = result.first().user_group_id +# +# users_table = User.__table__ +# stmt = users_table.select().where(users_table.c.is_superadmin.is_(True)) +# result = connection.execute(stmt) +# user_id = result.first().user_id +# +# stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id) +# connection.execute(stmt) +# connection.commit() diff --git a/gns3server/db/repositories/computes.py b/gns3server/db/repositories/computes.py index 445621da..b76842cb 100644 --- a/gns3server/db/repositories/computes.py +++ b/gns3server/db/repositories/computes.py @@ -79,11 +79,16 @@ class ComputesRepository(BaseRepository): if password: update_values["password"] = password.get_secret_value() - query = update(models.Compute).where(models.Compute.compute_id == compute_id).values(update_values) + query = update(models.Compute).\ + where(models.Compute.compute_id == compute_id).\ + values(update_values) await self._db_session.execute(query) await self._db_session.commit() - return await self.get_compute(compute_id) + compute_db = await self.get_compute(compute_id) + if compute_db: + await self._db_session.refresh(compute_db) # force refresh of updated_at value + return compute_db async def delete_compute(self, compute_id: UUID) -> bool: diff --git a/gns3server/db/repositories/rbac.py b/gns3server/db/repositories/rbac.py new file mode 100644 index 00000000..6e1096c9 --- /dev/null +++ b/gns3server/db/repositories/rbac.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python +# +# Copyright (C) 2020 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from uuid import UUID +from typing import Optional, List, Union +from sqlalchemy import select, update, delete +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from .base import BaseRepository + +import gns3server.db.models as models +from gns3server.schemas.controller.rbac import HTTPMethods, PermissionAction +from gns3server import schemas + +import logging + +log = logging.getLogger(__name__) + + +class RbacRepository(BaseRepository): + + def __init__(self, db_session: AsyncSession) -> None: + + super().__init__(db_session) + + async def get_role(self, role_id: UUID) -> Optional[models.Role]: + """ + Get a role by its ID. + """ + + query = select(models.Role).\ + options(selectinload(models.Role.permissions)).\ + where(models.Role.role_id == role_id) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_role_by_name(self, name: str) -> Optional[models.Role]: + """ + Get a role by its name. + """ + + query = select(models.Role).\ + options(selectinload(models.Role.permissions)).\ + where(models.Role.name == name) + #query = select(models.Role).where(models.Role.name == name) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_roles(self) -> List[models.Role]: + """ + Get all roles. + """ + + query = select(models.Role).options(selectinload(models.Role.permissions)) + result = await self._db_session.execute(query) + return result.scalars().all() + + async def create_role(self, role_create: schemas.RoleCreate) -> models.Role: + """ + Create a new role. + """ + + db_role = models.Role( + name=role_create.name, + description=role_create.description, + ) + self._db_session.add(db_role) + await self._db_session.commit() + #await self._db_session.refresh(db_role) + return await self.get_role(db_role.role_id) + + async def update_role( + self, + role_id: UUID, + role_update: schemas.RoleUpdate + ) -> Optional[models.Role]: + """ + Update a role. + """ + + update_values = role_update.dict(exclude_unset=True) + query = update(models.Role).\ + where(models.Role.role_id == role_id).\ + values(update_values) + + await self._db_session.execute(query) + await self._db_session.commit() + role_db = await self.get_role(role_id) + if role_db: + await self._db_session.refresh(role_db) # force refresh of updated_at value + return role_db + + async def delete_role(self, role_id: UUID) -> bool: + """ + Delete a role. + """ + + query = delete(models.Role).where(models.Role.role_id == role_id) + result = await self._db_session.execute(query) + await self._db_session.commit() + return result.rowcount > 0 + + async def add_permission_to_role( + self, + role_id: UUID, + permission: models.Permission + ) -> Union[None, models.Role]: + """ + Add a permission to a role. + """ + + query = select(models.Role).\ + options(selectinload(models.Role.permissions)).\ + where(models.Role.role_id == role_id) + result = await self._db_session.execute(query) + role_db = result.scalars().first() + if not role_db: + return None + + role_db.permissions.append(permission) + await self._db_session.commit() + await self._db_session.refresh(role_db) + return role_db + + async def remove_permission_from_role( + self, + role_id: UUID, + permission: models.Permission + ) -> Union[None, models.Role]: + """ + Remove a permission from a role. + """ + + query = select(models.Role).\ + options(selectinload(models.Role.permissions)).\ + where(models.Role.role_id == role_id) + result = await self._db_session.execute(query) + role_db = result.scalars().first() + if not role_db: + return None + + role_db.permissions.remove(permission) + await self._db_session.commit() + await self._db_session.refresh(role_db) + return role_db + + async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]: + """ + Get all the role permissions. + """ + + query = select(models.Permission).\ + join(models.Permission.roles).\ + filter(models.Role.role_id == role_id) + + result = await self._db_session.execute(query) + return result.scalars().all() + + async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]: + """ + Get a permission by its ID. + """ + + query = select(models.Permission).where(models.Permission.permission_id == permission_id) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_permission_by_path(self, path: str) -> Optional[models.Permission]: + """ + Get a permission by its path. + """ + + query = select(models.Permission).where(models.Permission.path == path) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_permissions(self) -> List[models.Permission]: + """ + Get all permissions. + """ + + query = select(models.Permission) + result = await self._db_session.execute(query) + return result.scalars().all() + + async def check_permission_exists(self, permission_create: schemas.PermissionCreate) -> bool: + """ + Check if a permission exists. + """ + + query = select(models.Permission).\ + where(models.Permission.methods == permission_create.methods, + models.Permission.path == permission_create.path, + models.Permission.action == permission_create.action) + result = await self._db_session.execute(query) + return result.scalars().first() is not None + + async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission: + """ + Create a new permission. + """ + + db_permission = models.Permission( + description=permission_create.description, + methods=permission_create.methods, + path=permission_create.path, + action=permission_create.action, + ) + self._db_session.add(db_permission) + await self._db_session.commit() + await self._db_session.refresh(db_permission) + return db_permission + + async def update_permission( + self, + permission_id: UUID, + permission_update: schemas.PermissionUpdate + ) -> Optional[models.Permission]: + """ + Update a permission. + """ + + update_values = permission_update.dict(exclude_unset=True) + query = update(models.Permission).\ + where(models.Permission.permission_id == permission_id).\ + values(update_values) + + await self._db_session.execute(query) + await self._db_session.commit() + permission_db = await self.get_permission(permission_id) + if permission_db: + await self._db_session.refresh(permission_db) # force refresh of updated_at value + return permission_db + + async def delete_permission(self, permission_id: UUID) -> bool: + """ + Delete a permission. + """ + + query = delete(models.Permission).where(models.Permission.permission_id == permission_id) + result = await self._db_session.execute(query) + await self._db_session.commit() + return result.rowcount > 0 + + def _match_permission( + self, + permissions: List[models.Permission], + method: str, + path: str + ) -> Union[None, models.Permission]: + """ + Match the methods and path with a permission. + """ + + for permission in permissions: + log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}") + if method not in permission.methods: + continue + if permission.path.endswith("/*") and path.startswith(permission.path[:-2]): + return permission + elif permission.path == path: + return permission + + async def get_user_permissions(self, user_id: UUID): + """ + Get all permissions from an user. + """ + + query = select(models.Permission).\ + join(models.User.permissions). \ + filter(models.User.user_id == user_id).\ + order_by(models.Permission.path) + + result = await self._db_session.execute(query) + return result.scalars().all() + + async def add_permission_to_user( + self, + user_id: UUID, + permission: models.Permission + ) -> Union[None, models.User]: + """ + Add a permission to an user. + """ + + query = select(models.User).\ + options(selectinload(models.User.permissions)).\ + where(models.User.user_id == user_id) + result = await self._db_session.execute(query) + user_db = result.scalars().first() + if not user_db: + return None + + user_db.permissions.append(permission) + await self._db_session.commit() + await self._db_session.refresh(user_db) + return user_db + + async def remove_permission_from_user( + self, + user_id: UUID, + permission: models.Permission + ) -> Union[None, models.User]: + """ + Remove a permission from a role. + """ + + query = select(models.User).\ + options(selectinload(models.User.permissions)).\ + where(models.User.user_id == user_id) + result = await self._db_session.execute(query) + user_db = result.scalars().first() + if not user_db: + return None + + user_db.permissions.remove(permission) + await self._db_session.commit() + await self._db_session.refresh(user_db) + return user_db + + async def add_permission_to_user_with_path(self, user_id: UUID, path: str) -> Union[None, models.User]: + """ + Add a permission to an user. + """ + + # Create a new permission with full rights on path + new_permission = schemas.PermissionCreate( + description=f"Allow access to {path}", + methods=[HTTPMethods.get, HTTPMethods.head, HTTPMethods.post, HTTPMethods.put, HTTPMethods.delete], + path=path, + action=PermissionAction.allow + ) + permission_db = await self.create_permission(new_permission) + + # Add the permission to the user + query = select(models.User).\ + options(selectinload(models.User.permissions)).\ + where(models.User.user_id == user_id) + + result = await self._db_session.execute(query) + user_db = result.scalars().first() + if not user_db: + return None + + user_db.permissions.append(permission_db) + await self._db_session.commit() + await self._db_session.refresh(user_db) + return user_db + + async def delete_all_permissions_with_path(self, path: str) -> None: + """ + Delete all permissions with path. + """ + + query = delete(models.Permission).\ + where(models.Permission.path.startswith(path)).\ + execution_options(synchronize_session=False) + result = await self._db_session.execute(query) + log.debug(f"{result.rowcount} permission(s) have been deleted") + + async def check_user_is_authorized(self, user_id: UUID, method: str, path: str) -> bool: + """ + Check if an user is authorized to access a resource. + """ + + query = select(models.Permission).\ + join(models.Permission.roles). \ + join(models.Role.groups). \ + join(models.UserGroup.users). \ + filter(models.User.user_id == user_id).\ + order_by(models.Permission.path) + + result = await self._db_session.execute(query) + permissions = result.scalars().all() + log.debug(f"RBAC: checking authorization for user '{user_id}' on {method} '{path}'") + matched_permission = self._match_permission(permissions, method, path) + if matched_permission: + log.debug(f"RBAC: matched role permission {matched_permission.methods} " + f"{matched_permission.path} {matched_permission.action}") + if matched_permission.action == "DENY": + return False + return True + + log.debug(f"RBAC: could not find a role permission, checking user permissions...") + permissions = await self.get_user_permissions(user_id) + matched_permission = self._match_permission(permissions, method, path) + if matched_permission: + log.debug(f"RBAC: matched user permission {matched_permission.methods} " + f"{matched_permission.path} {matched_permission.action}") + if matched_permission.action == "DENY": + return False + return True + + return False diff --git a/gns3server/db/repositories/templates.py b/gns3server/db/repositories/templates.py index e5be7971..4e6f50b1 100644 --- a/gns3server/db/repositories/templates.py +++ b/gns3server/db/repositories/templates.py @@ -70,11 +70,16 @@ class TemplatesRepository(BaseRepository): update_values = template_update.dict(exclude_unset=True) - query = update(models.Template).where(models.Template.template_id == template_id).values(update_values) + query = update(models.Template). \ + where(models.Template.template_id == template_id). \ + values(update_values) await self._db_session.execute(query) await self._db_session.commit() - return await self.get_template(template_id) + template_db = await self.get_template(template_id) + if template_db: + await self._db_session.refresh(template_db) # force refresh of updated_at value + return template_db async def delete_template(self, template_id: UUID) -> bool: diff --git a/gns3server/db/repositories/users.py b/gns3server/db/repositories/users.py index c93c741e..68d12f78 100644 --- a/gns3server/db/repositories/users.py +++ b/gns3server/db/repositories/users.py @@ -33,40 +33,59 @@ log = logging.getLogger(__name__) class UsersRepository(BaseRepository): + def __init__(self, db_session: AsyncSession) -> None: super().__init__(db_session) self._auth_service = auth_service async def get_user(self, user_id: UUID) -> Optional[models.User]: + """ + Get an user by its ID. + """ query = select(models.User).where(models.User.user_id == user_id) result = await self._db_session.execute(query) return result.scalars().first() async def get_user_by_username(self, username: str) -> Optional[models.User]: + """ + Get an user by its name. + """ query = select(models.User).where(models.User.username == username) result = await self._db_session.execute(query) return result.scalars().first() async def get_user_by_email(self, email: str) -> Optional[models.User]: + """ + Get an user by its email. + """ query = select(models.User).where(models.User.email == email) result = await self._db_session.execute(query) return result.scalars().first() async def get_users(self) -> List[models.User]: + """ + Get all users. + """ query = select(models.User) result = await self._db_session.execute(query) return result.scalars().all() async def create_user(self, user: schemas.UserCreate) -> models.User: + """ + Create a new user. + """ hashed_password = self._auth_service.hash_password(user.password.get_secret_value()) db_user = models.User( - username=user.username, email=user.email, full_name=user.full_name, hashed_password=hashed_password + username=user.username, + email=user.email, + full_name=user.full_name, + hashed_password=hashed_password ) self._db_session.add(db_user) await self._db_session.commit() @@ -74,19 +93,30 @@ class UsersRepository(BaseRepository): return db_user async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]: + """ + Update an user. + """ update_values = user_update.dict(exclude_unset=True) password = update_values.pop("password", None) if password: update_values["hashed_password"] = self._auth_service.hash_password(password=password.get_secret_value()) - query = update(models.User).where(models.User.user_id == user_id).values(update_values) + query = update(models.User).\ + where(models.User.user_id == user_id).\ + values(update_values) await self._db_session.execute(query) await self._db_session.commit() - return await self.get_user(user_id) + user_db = await self.get_user(user_id) + if user_db: + await self._db_session.refresh(user_db) # force refresh of updated_at value + return user_db async def delete_user(self, user_id: UUID) -> bool: + """ + Delete an user. + """ query = delete(models.User).where(models.User.user_id == user_id) result = await self._db_session.execute(query) @@ -94,6 +124,9 @@ class UsersRepository(BaseRepository): return result.rowcount > 0 async def authenticate_user(self, username: str, password: str) -> Optional[models.User]: + """ + Authenticate an user. + """ user = await self.get_user_by_username(username) if not user: @@ -110,6 +143,9 @@ class UsersRepository(BaseRepository): return user async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]: + """ + Get all user memberships (user groups). + """ query = select(models.UserGroup).\ join(models.UserGroup.users).\ @@ -119,24 +155,36 @@ class UsersRepository(BaseRepository): return result.scalars().all() async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]: + """ + Get an user group by its ID. + """ query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) result = await self._db_session.execute(query) return result.scalars().first() async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]: + """ + Get an user group by its name. + """ query = select(models.UserGroup).where(models.UserGroup.name == name) result = await self._db_session.execute(query) return result.scalars().first() async def get_user_groups(self) -> List[models.UserGroup]: + """ + Get all user groups. + """ query = select(models.UserGroup) result = await self._db_session.execute(query) return result.scalars().all() async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup: + """ + Create a new user group. + """ db_user_group = models.UserGroup(name=user_group.name) self._db_session.add(db_user_group) @@ -149,15 +197,26 @@ class UsersRepository(BaseRepository): user_group_id: UUID, user_group_update: schemas.UserGroupUpdate ) -> Optional[models.UserGroup]: + """ + Update an user group. + """ update_values = user_group_update.dict(exclude_unset=True) - query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values) + query = update(models.UserGroup).\ + where(models.UserGroup.user_group_id == user_group_id).\ + values(update_values) await self._db_session.execute(query) await self._db_session.commit() - return await self.get_user_group(user_group_id) + user_group_db = await self.get_user_group(user_group_id) + if user_group_db: + await self._db_session.refresh(user_group_db) # force refresh of updated_at value + return user_group_db async def delete_user_group(self, user_group_id: UUID) -> bool: + """ + Delete an user group. + """ query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) result = await self._db_session.execute(query) @@ -169,6 +228,9 @@ class UsersRepository(BaseRepository): user_group_id: UUID, user: models.User ) -> Union[None, models.UserGroup]: + """ + Add a member to an user group. + """ query = select(models.UserGroup).\ options(selectinload(models.UserGroup.users)).\ @@ -188,6 +250,9 @@ class UsersRepository(BaseRepository): user_group_id: UUID, user: models.User ) -> Union[None, models.UserGroup]: + """ + Remove a member from an user group. + """ query = select(models.UserGroup).\ options(selectinload(models.UserGroup.users)).\ @@ -203,6 +268,9 @@ class UsersRepository(BaseRepository): return user_group_db async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]: + """ + Get all members from an user group. + """ query = select(models.User).\ join(models.User.groups).\ @@ -210,3 +278,60 @@ class UsersRepository(BaseRepository): result = await self._db_session.execute(query) return result.scalars().all() + + async def add_role_to_user_group( + self, + user_group_id: UUID, + role: models.Role + ) -> Union[None, models.UserGroup]: + """ + Add a role to an user group. + """ + + query = select(models.UserGroup).\ + options(selectinload(models.UserGroup.roles)).\ + where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + user_group_db = result.scalars().first() + if not user_group_db: + return None + + user_group_db.roles.append(role) + await self._db_session.commit() + await self._db_session.refresh(user_group_db) + return user_group_db + + async def remove_role_from_user_group( + self, + user_group_id: UUID, + role: models.Role + ) -> Union[None, models.UserGroup]: + """ + Remove a role from an user group. + """ + + query = select(models.UserGroup).\ + options(selectinload(models.UserGroup.roles)).\ + where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + user_group_db = result.scalars().first() + if not user_group_db: + return None + + user_group_db.roles.remove(role) + await self._db_session.commit() + await self._db_session.refresh(user_group_db) + return user_group_db + + async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]: + """ + Get all roles from an user group. + """ + + query = select(models.Role). \ + options(selectinload(models.Role.permissions)). \ + join(models.UserGroup.roles). \ + filter(models.UserGroup.user_group_id == user_group_id) + + result = await self._db_session.execute(query) + return result.scalars().all() diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py index f32985a5..e5683059 100644 --- a/gns3server/schemas/__init__.py +++ b/gns3server/schemas/__init__.py @@ -28,6 +28,7 @@ from .controller.gns3vm import GNS3VM from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile from .controller.users import UserCreate, UserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup +from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission from .controller.tokens import Token from .controller.snapshots import SnapshotCreate, Snapshot from .controller.iou_license import IOULicense diff --git a/gns3server/schemas/controller/rbac.py b/gns3server/schemas/controller/rbac.py new file mode 100644 index 00000000..4967e5ef --- /dev/null +++ b/gns3server/schemas/controller/rbac.py @@ -0,0 +1,121 @@ +# +# Copyright (C) 2020 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from typing import Optional, List +from pydantic import BaseModel, validator +from uuid import UUID +from enum import Enum + +from .base import DateTimeModelMixin + + +class HTTPMethods(str, Enum): + """ + HTTP method type. + """ + + get = "GET" + head = "HEAD" + post = "POST" + patch = "PATCH" + put = "PUT" + delete = "DELETE" + + +class PermissionAction(str, Enum): + """ + Action to perform when permission is matched. + """ + + allow = "ALLOW" + deny = "DENY" + + +class PermissionBase(BaseModel): + """ + Common permission properties. + """ + + methods: List[HTTPMethods] + path: str + action: PermissionAction + description: Optional[str] = None + + class Config: + use_enum_values = True + + @validator("action", pre=True) + def action_uppercase(cls, v): + return v.upper() + + +class PermissionCreate(PermissionBase): + """ + Properties to create a permission. + """ + + pass + + +class PermissionUpdate(PermissionBase): + """ + Properties to update a role. + """ + + pass + + +class Permission(DateTimeModelMixin, PermissionBase): + + permission_id: UUID + + class Config: + orm_mode = True + + +class RoleBase(BaseModel): + """ + Common role properties. + """ + + name: Optional[str] = None + description: Optional[str] = None + + +class RoleCreate(RoleBase): + """ + Properties to create a role. + """ + + name: str + + +class RoleUpdate(RoleBase): + """ + Properties to update a role. + """ + + pass + + +class Role(DateTimeModelMixin, RoleBase): + + role_id: UUID + builtin: bool + permissions: List[Permission] + + class Config: + orm_mode = True diff --git a/gns3server/schemas/controller/users.py b/gns3server/schemas/controller/users.py index 28d40688..97988c53 100644 --- a/gns3server/schemas/controller/users.py +++ b/gns3server/schemas/controller/users.py @@ -85,7 +85,7 @@ class UserGroupUpdate(UserGroupBase): class UserGroup(DateTimeModelMixin, UserGroupBase): user_group_id: UUID - is_updatable: bool + builtin: bool class Config: orm_mode = True diff --git a/requirements.txt b/requirements.txt index 20be3627..f58a051c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ uvicorn==0.13.4 -fastapi==0.64.0 -websockets==9.0.1 +fastapi==0.65.1 +websockets==9.1 python-multipart==0.0.5 aiohttp==3.7.4.post0 -aiofiles==0.6.0 -Jinja2==2.11.3 +aiofiles==0.7.0 +Jinja2==3.0.1 sentry-sdk==1.1.0 psutil==5.8.0 async-timeout==3.0.1 distro==1.5.0 py-cpuinfo==8.0.0 -sqlalchemy==1.4.14 +sqlalchemy==1.4.17 aiosqlite===0.17.0 passlib[bcrypt]==1.7.4 python-jose==3.2.0 diff --git a/tests/api/routes/controller/test_groups.py b/tests/api/routes/controller/test_groups.py index 7551ab9b..b762d5c8 100644 --- a/tests/api/routes/controller/test_groups.py +++ b/tests/api/routes/controller/test_groups.py @@ -22,7 +22,10 @@ from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.rbac import RbacRepository from gns3server.schemas.controller.users import User +from gns3server.schemas.controller.rbac import Role +from gns3server import schemas pytestmark = pytest.mark.asyncio @@ -47,7 +50,7 @@ class TestGroupRoutes: response = await client.get(app.url_path_for("get_user_groups")) assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 4 # 3 default groups + group1 + assert len(response.json()) == 3 # 2 default groups + group1 async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: @@ -103,6 +106,9 @@ class TestGroupRoutes: response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id)) assert response.status_code == status.HTTP_403_FORBIDDEN + +class TestGroupMembersRoutes: + async def test_add_member_to_group( self, app: FastAPI, @@ -163,3 +169,84 @@ class TestGroupRoutes: assert response.status_code == status.HTTP_204_NO_CONTENT members = await user_repo.get_user_group_members(group_in_db.user_group_id) assert len(members) == 0 + + +@pytest.fixture +async def test_role(db_session: AsyncSession) -> Role: + + new_role = schemas.RoleCreate( + name="TestRole", + description="This is my test role" + ) + rbac_repo = RbacRepository(db_session) + existing_role = await rbac_repo.get_role_by_name(new_role.name) + if existing_role: + return existing_role + return await rbac_repo.create_role(new_role) + + +class TestGroupRolesRoutes: + + async def test_add_role_to_group( + self, + app: FastAPI, + client: AsyncClient, + test_role: Role, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + response = await client.put( + app.url_path_for( + "add_role_to_group", + user_group_id=group_in_db.user_group_id, + role_id=str(test_role.role_id) + ) + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + roles = await user_repo.get_user_group_roles(group_in_db.user_group_id) + assert len(roles) == 2 # 1 default role + 1 custom role + for role in roles: + if not role.builtin: + assert role.name == test_role.name + + async def test_get_user_group_roles( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + response = await client.get( + app.url_path_for( + "get_user_group_roles", + user_group_id=group_in_db.user_group_id) + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 2 # 1 default role + 1 custom role + + async def test_remove_role_from_group( + self, + app: FastAPI, + client: AsyncClient, + test_role: Role, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + + response = await client.delete( + app.url_path_for( + "remove_role_from_group", + user_group_id=group_in_db.user_group_id, + role_id=test_role.role_id + ), + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + roles = await user_repo.get_user_group_roles(group_in_db.user_group_id) + assert len(roles) == 1 # 1 default role + assert roles[0].name != test_role.name diff --git a/tests/api/routes/controller/test_permissions.py b/tests/api/routes/controller/test_permissions.py new file mode 100644 index 00000000..1bd4e0ff --- /dev/null +++ b/tests/api/routes/controller/test_permissions.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import pytest + +from fastapi import FastAPI, status +from httpx import AsyncClient + +from sqlalchemy.ext.asyncio import AsyncSession +from gns3server.db.repositories.rbac import RbacRepository + +pytestmark = pytest.mark.asyncio + + +class TestPermissionRoutes: + + async def test_create_permission(self, app: FastAPI, client: AsyncClient) -> None: + + new_permission = { + "methods": ["GET"], + "path": "/templates", + "action": "ALLOW" + } + response = await client.post(app.url_path_for("create_permission"), json=new_permission) + assert response.status_code == status.HTTP_201_CREATED + + async def test_get_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + rbac_repo = RbacRepository(db_session) + permission_in_db = await rbac_repo.get_permission_by_path("/templates") + response = await client.get(app.url_path_for("get_permission", permission_id=permission_in_db.permission_id)) + assert response.status_code == status.HTTP_200_OK + assert response.json()["permission_id"] == str(permission_in_db.permission_id) + + async def test_list_permissions(self, app: FastAPI, client: AsyncClient) -> None: + + response = await client.get(app.url_path_for("get_permissions")) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 6 # 5 default permissions + 1 custom permission + + async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + rbac_repo = RbacRepository(db_session) + permission_in_db = await rbac_repo.get_permission_by_path("/templates") + + update_permission = { + "methods": ["GET"], + "path": "/appliances", + "action": "ALLOW" + } + response = await client.put( + app.url_path_for("update_permission", permission_id=permission_in_db.permission_id), + json=update_permission + ) + assert response.status_code == status.HTTP_200_OK + updated_permission_in_db = await rbac_repo.get_permission(permission_in_db.permission_id) + assert updated_permission_in_db.path == "/appliances" + + async def test_delete_permission( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + permission_in_db = await rbac_repo.get_permission_by_path("/appliances") + response = await client.delete(app.url_path_for("delete_permission", permission_id=permission_in_db.permission_id)) + assert response.status_code == status.HTTP_204_NO_CONTENT diff --git a/tests/api/routes/controller/test_roles.py b/tests/api/routes/controller/test_roles.py new file mode 100644 index 00000000..20646d92 --- /dev/null +++ b/tests/api/routes/controller/test_roles.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import pytest + +from fastapi import FastAPI, status +from httpx import AsyncClient + +from sqlalchemy.ext.asyncio import AsyncSession +from gns3server.db.repositories.rbac import RbacRepository +from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction +from gns3server import schemas + +pytestmark = pytest.mark.asyncio + + +class TestRolesRoutes: + + async def test_create_role(self, app: FastAPI, client: AsyncClient) -> None: + + new_role = {"name": "role1"} + response = await client.post(app.url_path_for("create_role"), json=new_role) + assert response.status_code == status.HTTP_201_CREATED + + async def test_get_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("role1") + response = await client.get(app.url_path_for("get_role", role_id=role_in_db.role_id)) + assert response.status_code == status.HTTP_200_OK + assert response.json()["role_id"] == str(role_in_db.role_id) + + async def test_list_roles(self, app: FastAPI, client: AsyncClient) -> None: + + response = await client.get(app.url_path_for("get_roles")) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 3 # 2 default roles + role1 + + async def test_update_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("role1") + + update_role = {"name": "role42"} + response = await client.put( + app.url_path_for("update_role", role_id=role_in_db.role_id), + json=update_role + ) + assert response.status_code == status.HTTP_200_OK + updated_role_in_db = await rbac_repo.get_role(role_in_db.role_id) + assert updated_role_in_db.name == "role42" + + async def test_cannot_update_builtin_user_role( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("User") + update_role = {"name": "Hackers"} + response = await client.put( + app.url_path_for("update_role", role_id=role_in_db.role_id), + json=update_role + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_delete_role( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("role42") + response = await client.delete(app.url_path_for("delete_role", role_id=role_in_db.role_id)) + assert response.status_code == status.HTTP_204_NO_CONTENT + + async def test_cannot_delete_builtin_administrator_role( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("Administrator") + response = await client.delete(app.url_path_for("delete_role", role_id=role_in_db.role_id)) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@pytest.fixture +async def test_permission(db_session: AsyncSession) -> Permission: + + new_permission = schemas.PermissionCreate( + methods=[HTTPMethods.get], + path="/statistics", + action=PermissionAction.allow + ) + rbac_repo = RbacRepository(db_session) + existing_permission = await rbac_repo.get_permission_by_path("/statistics") + if existing_permission: + return existing_permission + return await rbac_repo.create_permission(new_permission) + + +class TestRolesPermissionsRoutes: + + async def test_add_permission_to_role( + self, + app: FastAPI, + client: AsyncClient, + test_permission: Permission, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("User") + + response = await client.put( + app.url_path_for( + "add_permission_to_role", + role_id=role_in_db.role_id, + permission_id=str(test_permission.permission_id) + ) + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) + assert len(permissions) == 5 # 4 default permissions + 1 custom permission + + async def test_get_role_permissions( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("User") + + response = await client.get( + app.url_path_for( + "get_role_permissions", + role_id=role_in_db.role_id) + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 5 # 4 default permissions + 1 custom permission + + async def test_remove_role_from_group( + self, + app: FastAPI, + client: AsyncClient, + test_permission: Permission, + db_session: AsyncSession + ) -> None: + + rbac_repo = RbacRepository(db_session) + role_in_db = await rbac_repo.get_role_by_name("User") + + response = await client.delete( + app.url_path_for( + "remove_permission_from_role", + role_id=role_in_db.role_id, + permission_id=str(test_permission.permission_id) + ), + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) + assert len(permissions) == 4 # 4 default permissions diff --git a/tests/api/routes/controller/test_users.py b/tests/api/routes/controller/test_users.py index f1a5dc41..ce0d8801 100644 --- a/tests/api/routes/controller/test_users.py +++ b/tests/api/routes/controller/test_users.py @@ -25,9 +25,12 @@ from jose import jwt from sqlalchemy.ext.asyncio import AsyncSession from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.rbac import RbacRepository +from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction from gns3server.services import auth_service from gns3server.config import Config from gns3server.schemas.controller.users import User +from gns3server import schemas import gns3server.db.models as models pytestmark = pytest.mark.asyncio @@ -266,7 +269,7 @@ class TestUserMe: test_user: User, ) -> None: - response = await authorized_client.get(app.url_path_for("get_current_active_user")) + response = await authorized_client.get(app.url_path_for("get_logged_in_user")) assert response.status_code == status.HTTP_200_OK user = User(**response.json()) assert user.username == test_user.username @@ -279,7 +282,7 @@ class TestUserMe: test_user: User, ) -> None: - response = await unauthorized_client.get(app.url_path_for("get_current_active_user")) + response = await unauthorized_client.get(app.url_path_for("get_logged_in_user")) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -329,15 +332,91 @@ class TestSuperAdmin: response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) assert response.status_code == status.HTTP_200_OK - async def test_super_admin_belongs_to_admin_group( + # async def test_super_admin_belongs_to_admin_group( + # self, + # app: FastAPI, + # client: AsyncClient, + # db_session: AsyncSession + # ) -> None: + # + # user_repo = UsersRepository(db_session) + # admin_in_db = await user_repo.get_user_by_username("admin") + # response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id)) + # assert response.status_code == status.HTTP_200_OK + # assert len(response.json()) == 1 + +@pytest.fixture +async def test_permission(db_session: AsyncSession) -> Permission: + + new_permission = schemas.PermissionCreate( + methods=[HTTPMethods.get], + path="/statistics", + action=PermissionAction.allow + ) + rbac_repo = RbacRepository(db_session) + existing_permission = await rbac_repo.get_permission_by_path("/statistics") + if existing_permission: + return existing_permission + return await rbac_repo.create_permission(new_permission) + + +class TestUserPermissionsRoutes: + + async def test_add_permission_to_user( self, app: FastAPI, client: AsyncClient, + test_user: User, + test_permission: Permission, db_session: AsyncSession ) -> None: - user_repo = UsersRepository(db_session) - admin_in_db = await user_repo.get_user_by_username("admin") - response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id)) + response = await client.put( + app.url_path_for( + "add_permission_to_user", + user_id=str(test_user.user_id), + permission_id=str(test_permission.permission_id) + ) + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + rbac_repo = RbacRepository(db_session) + permissions = await rbac_repo.get_user_permissions(test_user.user_id) + assert len(permissions) == 1 + assert permissions[0].permission_id == test_permission.permission_id + + async def test_get_user_permissions( + self, + app: FastAPI, + client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + response = await client.get( + app.url_path_for( + "get_user_permissions", + user_id=str(test_user.user_id)) + ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) == 1 + + async def test_remove_permission_from_user( + self, + app: FastAPI, + client: AsyncClient, + test_user: User, + test_permission: Permission, + db_session: AsyncSession + ) -> None: + + response = await client.delete( + app.url_path_for( + "remove_permission_from_user", + user_id=str(test_user.user_id), + permission_id=str(test_permission.permission_id) + ), + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + rbac_repo = RbacRepository(db_session) + permissions = await rbac_repo.get_user_permissions(test_user.user_id) + assert len(permissions) == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 669f777d..5cd3fa43 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,7 +124,12 @@ async def test_user(db_session: AsyncSession) -> User: existing_user = await user_repo.get_user_by_username(new_user.username) if existing_user: return existing_user - return await user_repo.create_user(new_user) + user = await user_repo.create_user(new_user) + + # add new user to "Users group + group = await user_repo.get_user_group_by_name("Users") + await user_repo.add_member_to_user_group(group.user_group_id, user) + return user @pytest.fixture diff --git a/tests/controller/test_rbac.py b/tests/controller/test_rbac.py new file mode 100644 index 00000000..faa4e6df --- /dev/null +++ b/tests/controller/test_rbac.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import pytest + +from fastapi import FastAPI, status +from httpx import AsyncClient + +from sqlalchemy.ext.asyncio import AsyncSession +from gns3server.db.repositories.rbac import RbacRepository +from gns3server.db.models import User + +pytestmark = pytest.mark.asyncio + + +class TestPermissions: + + @pytest.mark.parametrize( + "method, path, result", + ( + ("GET", "/users", False), + ("GET", "/projects", True), + ("GET", "/projects/e451ad73-2519-4f83-87fe-a8e821792d44", False), + ("POST", "/projects", True), + ("GET", "/templates", True), + ("GET", "/templates/62e92cf1-244a-4486-8dae-b95439b54da9", False), + ("POST", "/templates", True), + ("GET", "/computes", True), + ("GET", "/computes/local", True), + ("GET", "/symbols", True), + ("GET", "/symbols/default_symbols", True), + ), + ) + async def test_default_permissions_user_group( + self, + app: FastAPI, + authorized_client: AsyncClient, + test_user: User, + db_session: AsyncSession, + method: str, + path: str, + result: bool + ) -> None: + + rbac_repo = RbacRepository(db_session) + authorized = await rbac_repo.check_user_is_authorized(test_user.user_id, method, path) + assert authorized == result + + +class TestProjectsWithRbac: + + async def test_admin_create_project(self, app: FastAPI, client: AsyncClient): + + params = {"name": "Admin project"} + response = await client.post(app.url_path_for("create_project"), json=params) + assert response.status_code == status.HTTP_201_CREATED + + async def test_user_only_access_own_projects( + self, + app: FastAPI, + authorized_client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + params = {"name": "User project"} + response = await authorized_client.post(app.url_path_for("create_project"), json=params) + assert response.status_code == status.HTTP_201_CREATED + project_id = response.json()["project_id"] + + rbac_repo = RbacRepository(db_session) + permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id) + assert len(permissions_in_db) == 1 + assert permissions_in_db[0].path == f"/projects/{project_id}/*" + + response = await authorized_client.get(app.url_path_for("get_projects")) + assert response.status_code == status.HTTP_200_OK + projects = response.json() + assert len(projects) == 1 + + async def test_admin_access_all_projects(self, app: FastAPI, client: AsyncClient): + + response = await client.get(app.url_path_for("get_projects")) + assert response.status_code == status.HTTP_200_OK + projects = response.json() + assert len(projects) == 2 + + async def test_admin_user_give_permission_on_project( + self, + app: FastAPI, + client: AsyncClient, + test_user: User + ): + + response = await client.get(app.url_path_for("get_projects")) + assert response.status_code == status.HTTP_200_OK + projects = response.json() + project_id = None + for project in projects: + if project["name"] == "Admin project": + project_id = project["project_id"] + break + + new_permission = { + "methods": ["GET"], + "path": f"/projects/{project_id}", + "action": "ALLOW" + } + response = await client.post(app.url_path_for("create_permission"), json=new_permission) + assert response.status_code == status.HTTP_201_CREATED + permission_id = response.json()["permission_id"] + + response = await client.put( + app.url_path_for( + "add_permission_to_user", + user_id=test_user.user_id, + permission_id=permission_id + ) + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + + async def test_user_access_admin_project( + self, + app: FastAPI, + authorized_client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + response = await authorized_client.get(app.url_path_for("get_projects")) + assert response.status_code == status.HTTP_200_OK + projects = response.json() + assert len(projects) == 2 + + +class TestTemplatesWithRbac: + + async def test_admin_create_template(self, app: FastAPI, client: AsyncClient): + + new_template = {"base_script_file": "vpcs_base_config.txt", + "category": "guest", + "console_auto_start": False, + "console_type": "telnet", + "default_name_format": "PC{0}", + "name": "ADMIN_VPCS_TEMPLATE", + "compute_id": "local", + "symbol": ":/symbols/vpcs_guest.svg", + "template_type": "vpcs"} + + response = await client.post(app.url_path_for("create_template"), json=new_template) + assert response.status_code == status.HTTP_201_CREATED + + async def test_user_only_access_own_templates( + self, app: FastAPI, + authorized_client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + new_template = {"base_script_file": "vpcs_base_config.txt", + "category": "guest", + "console_auto_start": False, + "console_type": "telnet", + "default_name_format": "PC{0}", + "name": "USER_VPCS_TEMPLATE", + "compute_id": "local", + "symbol": ":/symbols/vpcs_guest.svg", + "template_type": "vpcs"} + + response = await authorized_client.post(app.url_path_for("create_template"), json=new_template) + assert response.status_code == status.HTTP_201_CREATED + template_id = response.json()["template_id"] + + rbac_repo = RbacRepository(db_session) + permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id) + assert len(permissions_in_db) == 1 + assert permissions_in_db[0].path == f"/templates/{template_id}/*" + + response = await authorized_client.get(app.url_path_for("get_templates")) + assert response.status_code == status.HTTP_200_OK + templates = [template for template in response.json() if template["builtin"] is False] + assert len(templates) == 1 + + async def test_admin_access_all_templates(self, app: FastAPI, client: AsyncClient): + + response = await client.get(app.url_path_for("get_templates")) + assert response.status_code == status.HTTP_200_OK + templates = [template for template in response.json() if template["builtin"] is False] + assert len(templates) == 2