1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-01-27 16:31:02 +00:00

Use an ACL table to check for privileges

This commit is contained in:
grossmj 2023-08-27 18:20:42 +10:00
parent 6bd855b3c5
commit 60ce1172e0
30 changed files with 1195 additions and 1423 deletions

View File

@ -32,7 +32,7 @@ from . import images
from . import users from . import users
from . import groups from . import groups
from . import roles from . import roles
from . import permissions from . import acl
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user
@ -56,10 +56,10 @@ router.include_router(
) )
router.include_router( router.include_router(
permissions.router, acl.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/permissions", prefix="/acl",
tags=["Permissions"] tags=["ACL"]
) )
router.include_router( router.include_router(

View File

@ -0,0 +1,145 @@
#!/usr/bin/env python
#
# Copyright (C) 2023 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
API routes for ACL.
"""
import re
from fastapi import APIRouter, Depends, Request, status
from fastapi.routing import APIRoute
from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
)
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.database import get_repository
from .dependencies.authentication import get_current_active_user
import logging
log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.ACE])
async def get_aces(
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.ACE]:
"""
Get all ACL entries.
"""
return await rbac_repo.get_aces()
@router.post("", response_model=schemas.ACE, status_code=status.HTTP_201_CREATED)
async def create_ace(
request: Request,
ace_create: schemas.ACECreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.ACE:
"""
Create a new ACL entry.
"""
for route in request.app.routes:
if isinstance(route, APIRoute):
# remove the prefix (e.g. "/v3") from the route path
route_path = re.sub(r"^/v[0-9]", "", route.path)
# replace route path ID parameters by a UUID regex
route_path = re.sub(r"{\w+_id}", "[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", route_path)
# replace remaining route path parameters by a word matching regex
route_path = re.sub(r"/{[\w:]+}", r"/\\w+", route_path)
if re.fullmatch(route_path, ace_create.path):
log.info("Creating ACE for route path", ace_create.path, route_path)
return await rbac_repo.create_ace(ace_create)
raise ControllerBadRequestError(f"Path '{ace_create.path}' doesn't match any existing endpoint")
@router.get("/{ace_id}", response_model=schemas.ACE)
async def get_ace(
ace_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> schemas.ACE:
"""
Get an ACL entry.
"""
ace = await rbac_repo.get_ace(ace_id)
if not ace:
raise ControllerNotFoundError(f"ACL entry '{ace_id}' not found")
return ace
@router.put("/{ace_id}", response_model=schemas.ACE)
async def update_ace(
ace_id: UUID,
ace_update: schemas.ACEUpdate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.ACE:
"""
Update an ACL entry.
"""
ace = await rbac_repo.get_ace(ace_id)
if not ace:
raise ControllerNotFoundError(f"ACL entry '{ace_id}' not found")
return await rbac_repo.update_ace(ace_id, ace_update)
@router.delete("/{ace_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_ace(
ace_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Delete an ACL entry.
"""
ace = await rbac_repo.get_ace(ace_id)
if not ace:
raise ControllerNotFoundError(f"ACL entry '{ace_id}' not found")
success = await rbac_repo.delete_ace(ace_id)
if not success:
raise ControllerNotFoundError(f"ACL entry '{ace_id}' could not be deleted")
# @router.post("/prune", status_code=status.HTTP_204_NO_CONTENT)
# async def prune_permissions(
# rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
# ) -> None:
# """
# Prune orphaned permissions.
# """
#
# await rbac_repo.prune_permissions()

View File

@ -74,21 +74,6 @@ async def get_current_active_user(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# remove the prefix (e.g. "/v3") from URL path
path = re.sub(r"^/v[0-9]", "", request.url.path)
# special case: always authorize access to the "/users/me" endpoint
if path == "/users/me":
return current_user
authorized = await rbac_repo.check_user_is_authorized(current_user.user_id, request.method, path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{current_user.user_id}' on {request.method} '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user return current_user
@ -96,7 +81,6 @@ async def get_current_active_user_from_websocket(
websocket: WebSocket, websocket: WebSocket,
token: str = Query(...), token: str = Query(...),
user_repo: UsersRepository = Depends(get_repository(UsersRepository)), user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> Optional[schemas.User]: ) -> Optional[schemas.User]:
await websocket.accept() await websocket.accept()
@ -121,18 +105,6 @@ async def get_current_active_user_from_websocket(
detail=f"'{username}' is not an active user" detail=f"'{username}' is not an active user"
) )
# remove the prefix (e.g. "/v3") from URL path
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
# there are no HTTP methods for web sockets, assuming "GET"...
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{user.user_id}' on '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return user return user
except HTTPException as e: except HTTPException as e:

View File

@ -0,0 +1,78 @@
#
# Copyright (C) 2023 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import re
from fastapi import Request, WebSocket, Depends, HTTPException
from gns3server import schemas
from gns3server.db.repositories.rbac import RbacRepository
from .authentication import get_current_active_user, get_current_active_user_from_websocket
from .database import get_repository
import logging
log = logging.getLogger()
def has_privilege(
privilege_name: str
):
async def get_user_and_check_privilege(
request: Request,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
):
if not current_user.is_superadmin:
path = re.sub(r"^/v[0-9]", "", request.url.path) # remove the prefix (e.g. "/v3") from URL path
print(f"Checking user {current_user.username} has privilege {privilege_name} on '{path}'")
if not await rbac_repo.check_user_has_privilege(current_user.user_id, path, privilege_name):
raise HTTPException(status_code=403, detail=f"Permission denied (privilege {privilege_name} is required)")
return current_user
return get_user_and_check_privilege
def has_privilege_on_websocket(
privilege_name: str
):
async def get_user_and_check_privilege(
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
):
if not current_user.is_superadmin:
path = re.sub(r"^/v[0-9]", "", websocket.url.path) # remove the prefix (e.g. "/v3") from URL path
log.debug(f"Checking user {current_user.username} has privilege {privilege_name} on '{path}'")
if not await rbac_repo.check_user_has_privilege(current_user.user_id, path, privilege_name):
raise HTTPException(status_code=403, detail=f"Permission denied (privilege {privilege_name} is required)")
return current_user
return get_user_and_check_privilege
# class PrivilegeChecker:
#
# def __init__(self, required_privilege: str) -> None:
# self._required_privilege = required_privilege
#
# async def __call__(
# self,
# current_user: schemas.User = Depends(get_current_active_user),
# rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
# ) -> bool:
#
# if not await rbac_repo.check_user_has_privilege(current_user.user_id, "/projects", self._required_privilege):
# raise HTTPException(status_code=403, detail=f"Permission denied (privilege {self._required_privilege} is required)")
# return True
# Depends(PrivilegeChecker("Project.Audit"))

View File

@ -78,7 +78,7 @@ async def get_user_group(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> schemas.UserGroup: ) -> schemas.UserGroup:
""" """
Get an user group. Get a user group.
""" """
user_group = await users_repo.get_user_group(user_group_id) user_group = await users_repo.get_user_group(user_group_id)
@ -94,7 +94,7 @@ async def update_user_group(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.UserGroup: ) -> schemas.UserGroup:
""" """
Update an user group. Update a user group.
""" """
user_group = await users_repo.get_user_group(user_group_id) user_group = await users_repo.get_user_group(user_group_id)
if not user_group: if not user_group:
@ -115,7 +115,7 @@ async def delete_user_group(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> None: ) -> None:
""" """
Delete an user group Delete a user group
""" """
user_group = await users_repo.get_user_group(user_group_id) user_group = await users_repo.get_user_group(user_group_id)
@ -152,7 +152,7 @@ async def add_member_to_group(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> None: ) -> None:
""" """
Add member to an user group. Add member to a user group.
""" """
user = await users_repo.get_user(user_id) user = await users_repo.get_user(user_id)
@ -174,7 +174,7 @@ async def remove_member_from_group(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> None: ) -> None:
""" """
Remove member from an user group. Remove member from a user group.
""" """
user = await users_repo.get_user(user_id) user = await users_repo.get_user(user_id)
@ -184,61 +184,3 @@ async def remove_member_from_group(
user_group = await users_repo.remove_member_from_user_group(user_group_id, user) user_group = await users_repo.remove_member_from_user_group(user_group_id, user)
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
@router.get("/{user_group_id}/roles", response_model=List[schemas.Role])
async def get_user_group_roles(
user_group_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.Role]:
"""
Get all user group roles.
"""
return await users_repo.get_user_group_roles(user_group_id)
@router.put(
"/{user_group_id}/roles/{role_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def add_role_to_group(
user_group_id: UUID,
role_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> Response:
"""
Add role to an user group.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
user_group = await users_repo.add_role_to_user_group(user_group_id, role)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
@router.delete(
"/{user_group_id}/roles/{role_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_role_from_group(
user_group_id: UUID,
role_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Remove role from an user group.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
user_group = await users_repo.remove_role_from_user_group(user_group_id, role)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")

View File

@ -1,161 +0,0 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
API routes for permissions.
"""
import re
from fastapi import APIRouter, Depends, Response, Request, status
from fastapi.routing import APIRoute
from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
)
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.database import get_repository
from .dependencies.authentication import get_current_active_user
import logging
log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.Permission])
async def get_permissions(
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]:
"""
Get all permissions.
"""
return await rbac_repo.get_permissions()
@router.post("", response_model=schemas.Permission, status_code=status.HTTP_201_CREATED)
async def create_permission(
request: Request,
permission_create: schemas.PermissionCreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Permission:
"""
Create a new permission.
"""
# TODO: should we prevent having multiple permissions with same methods/path?
#if await rbac_repo.check_permission_exists(permission_create):
# raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path} "
# f"{permission_create.action}' already exists")
for route in request.app.routes:
if isinstance(route, APIRoute):
# remove the prefix (e.g. "/v3") from the route path
route_path = re.sub(r"^/v[0-9]", "", route.path)
# replace route path ID parameters by an UUID regex
route_path = re.sub(r"{\w+_id}", "[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", route_path)
# replace remaining route path parameters by an word matching regex
route_path = re.sub(r"/{[\w:]+}", r"/\\w+", route_path)
# the permission can match multiple routes
if permission_create.path.endswith("/*"):
route_path += r"/.*"
if re.fullmatch(route_path, permission_create.path):
for method in permission_create.methods:
if method in list(route.methods):
# check user has the right to add the permission (i.e has already to right on the path)
if not await rbac_repo.check_user_is_authorized(current_user.user_id, method, permission_create.path):
raise ControllerForbiddenError(f"User '{current_user.username}' doesn't have the rights to "
f"add a permission on {method} {permission_create.path} or "
f"the endpoint doesn't exist")
return await rbac_repo.create_permission(permission_create)
raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path}' "
f"doesn't match any existing endpoint")
@router.get("/{permission_id}", response_model=schemas.Permission)
async def get_permission(
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> schemas.Permission:
"""
Get a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
return permission
@router.put("/{permission_id}", response_model=schemas.Permission)
async def update_permission(
permission_id: UUID,
permission_update: schemas.PermissionUpdate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Permission:
"""
Update a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
return await rbac_repo.update_permission(permission_id, permission_update)
@router.delete("/{permission_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_permission(
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Delete a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
success = await rbac_repo.delete_permission(permission_id)
if not success:
raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted")
@router.post("/prune", status_code=status.HTTP_204_NO_CONTENT)
async def prune_permissions(
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Prune orphaned permissions.
"""
await rbac_repo.prune_permissions()

View File

@ -49,7 +49,8 @@ from gns3server.db.repositories.rbac import RbacRepository
from gns3server.db.repositories.templates import TemplatesRepository from gns3server.db.repositories.templates import TemplatesRepository
from gns3server.services.templates import TemplatesService from gns3server.services.templates import TemplatesService
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket from .dependencies.authentication import get_current_active_user
from .dependencies.rbac import has_privilege, has_privilege_on_websocket
from .dependencies.database import get_repository from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -84,9 +85,11 @@ async def get_projects(
else: else:
user_projects = [] user_projects = []
for project in controller.projects.values(): for project in controller.projects.values():
authorized = await rbac_repo.check_user_is_authorized( if await rbac_repo.check_user_has_privilege(
current_user.user_id, "GET", f"/projects/{project.id}") current_user.user_id,
if authorized: f"/projects/{project.id}",
"Project.Audit"
):
user_projects.append(project.asdict()) user_projects.append(project.asdict())
return user_projects return user_projects
@ -97,11 +100,10 @@ async def get_projects(
response_model=schemas.Project, response_model=schemas.Project,
response_model_exclude_unset=True, response_model_exclude_unset=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}}, responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}},
dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def create_project( async def create_project(
project_data: schemas.ProjectCreate, project_data: schemas.ProjectCreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project: ) -> schemas.Project:
""" """
Create a new project. Create a new project.
@ -109,12 +111,11 @@ async def create_project(
controller = Controller.instance() controller = Controller.instance()
project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True)) project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True))
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{project.id}/*")
return project.asdict() return project.asdict()
@router.get("/{project_id}", response_model=schemas.Project, dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}", response_model=schemas.Project, dependencies=[Depends(has_privilege("Project.Audit"))])
def get_project(project: Project = Depends(dep_project)) -> schemas.Project: async def get_project(project: Project = Depends(dep_project)) -> schemas.Project:
""" """
Return a project. Return a project.
""" """
@ -126,7 +127,7 @@ def get_project(project: Project = Depends(dep_project)) -> schemas.Project:
"/{project_id}", "/{project_id}",
response_model=schemas.Project, response_model=schemas.Project,
response_model_exclude_unset=True, response_model_exclude_unset=True,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Modify"))]
) )
async def update_project( async def update_project(
project_data: schemas.ProjectUpdate, project_data: schemas.ProjectUpdate,
@ -143,11 +144,10 @@ async def update_project(
@router.delete( @router.delete(
"/{project_id}", "/{project_id}",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def delete_project( async def delete_project(
project: Project = Depends(dep_project), project: Project = Depends(dep_project)
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None: ) -> None:
""" """
Delete a project. Delete a project.
@ -156,10 +156,9 @@ async def delete_project(
controller = Controller.instance() controller = Controller.instance()
await project.delete() await project.delete()
controller.remove_project(project) controller.remove_project(project)
await rbac_repo.delete_all_permissions_with_path(f"/projects/{project.id}")
@router.get("/{project_id}/stats", dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}/stats", dependencies=[Depends(has_privilege("Project.Audit"))])
def get_project_stats(project: Project = Depends(dep_project)) -> dict: def get_project_stats(project: Project = Depends(dep_project)) -> dict:
""" """
Return a project statistics. Return a project statistics.
@ -172,7 +171,7 @@ def get_project_stats(project: Project = Depends(dep_project)) -> dict:
"/{project_id}/close", "/{project_id}/close",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not close project"}}, responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not close project"}},
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def close_project(project: Project = Depends(dep_project)) -> None: async def close_project(project: Project = Depends(dep_project)) -> None:
""" """
@ -187,7 +186,7 @@ async def close_project(project: Project = Depends(dep_project)) -> None:
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=schemas.Project, response_model=schemas.Project,
responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not open project"}}, responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not open project"}},
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def open_project(project: Project = Depends(dep_project)) -> schemas.Project: async def open_project(project: Project = Depends(dep_project)) -> schemas.Project:
""" """
@ -203,7 +202,7 @@ async def open_project(project: Project = Depends(dep_project)) -> schemas.Proje
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=schemas.Project, response_model=schemas.Project,
responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not load project"}}, responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not load project"}},
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def load_project(path: str = Body(..., embed=True)) -> schemas.Project: async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
""" """
@ -216,7 +215,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
return project.asdict() return project.asdict()
@router.get("/{project_id}/notifications", dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}/notifications", dependencies=[Depends(has_privilege("Project.Audit"))])
async def project_http_notifications(project_id: UUID) -> StreamingResponse: async def project_http_notifications(project_id: UUID) -> StreamingResponse:
""" """
Receive project notifications about the controller from HTTP stream. Receive project notifications about the controller from HTTP stream.
@ -252,7 +251,7 @@ async def project_http_notifications(project_id: UUID) -> StreamingResponse:
async def project_ws_notifications( async def project_ws_notifications(
project_id: UUID, project_id: UUID,
websocket: WebSocket, websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket) current_user: schemas.User = Depends(has_privilege_on_websocket("Project.Audit"))
) -> None: ) -> None:
""" """
Receive project notifications about the controller from WebSocket. Receive project notifications about the controller from WebSocket.
@ -288,7 +287,7 @@ async def project_ws_notifications(
await project.close() await project.close()
@router.get("/{project_id}/export", dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}/export", dependencies=[Depends(has_privilege("Project.Audit"))])
async def export_project( async def export_project(
project: Project = Depends(dep_project), project: Project = Depends(dep_project),
include_snapshots: bool = False, include_snapshots: bool = False,
@ -345,7 +344,7 @@ async def export_project(
log.info(f"Project '{project.name}' exported in {time.time() - begin:.4f} seconds") log.info(f"Project '{project.name}' exported in {time.time() - begin:.4f} seconds")
# Will be raise if you have no space left or permission issue on your temporary directory # Will be raised if you have no space left or permission issue on your temporary directory
# RuntimeError: something was wrong during the zip process # RuntimeError: something was wrong during the zip process
except (ValueError, OSError, RuntimeError) as e: except (ValueError, OSError, RuntimeError) as e:
raise ConnectionError(f"Cannot export project: {e}") raise ConnectionError(f"Cannot export project: {e}")
@ -358,7 +357,7 @@ async def export_project(
"/{project_id}/import", "/{project_id}/import",
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=schemas.Project, response_model=schemas.Project,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def import_project( async def import_project(
project_id: UUID, project_id: UUID,
@ -394,13 +393,11 @@ async def import_project(
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=schemas.Project, response_model=schemas.Project,
responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not duplicate project"}}, responses={**responses, 409: {"model": schemas.ErrorMessage, "description": "Could not duplicate project"}},
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Allocate"))]
) )
async def duplicate_project( async def duplicate_project(
project_data: schemas.ProjectDuplicate, project_data: schemas.ProjectDuplicate,
project: Project = Depends(dep_project), project: Project = Depends(dep_project)
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project: ) -> schemas.Project:
""" """
Duplicate a project. Duplicate a project.
@ -410,11 +407,10 @@ async def duplicate_project(
new_project = await project.duplicate( new_project = await project.duplicate(
name=project_data.name, reset_mac_addresses=reset_mac_addresses name=project_data.name, reset_mac_addresses=reset_mac_addresses
) )
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{new_project.id}/*")
return new_project.asdict() return new_project.asdict()
@router.get("/{project_id}/locked", dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}/locked", dependencies=[Depends(has_privilege("Project.Audit"))])
async def locked_project(project: Project = Depends(dep_project)) -> bool: async def locked_project(project: Project = Depends(dep_project)) -> bool:
""" """
Returns whether a project is locked or not Returns whether a project is locked or not
@ -426,7 +422,7 @@ async def locked_project(project: Project = Depends(dep_project)) -> bool:
@router.post( @router.post(
"/{project_id}/lock", "/{project_id}/lock",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Modify"))]
) )
async def lock_project(project: Project = Depends(dep_project)) -> None: async def lock_project(project: Project = Depends(dep_project)) -> None:
""" """
@ -439,7 +435,7 @@ async def lock_project(project: Project = Depends(dep_project)) -> None:
@router.post( @router.post(
"/{project_id}/unlock", "/{project_id}/unlock",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Modify"))]
) )
async def unlock_project(project: Project = Depends(dep_project)) -> None: async def unlock_project(project: Project = Depends(dep_project)) -> None:
""" """
@ -449,7 +445,7 @@ async def unlock_project(project: Project = Depends(dep_project)) -> None:
project.unlock() project.unlock()
@router.get("/{project_id}/files/{file_path:path}", dependencies=[Depends(get_current_active_user)]) @router.get("/{project_id}/files/{file_path:path}", dependencies=[Depends(has_privilege("Project.Audit"))])
async def get_file(file_path: str, project: Project = Depends(dep_project)) -> FileResponse: async def get_file(file_path: str, project: Project = Depends(dep_project)) -> FileResponse:
""" """
Return a file from a project. Return a file from a project.
@ -472,7 +468,7 @@ async def get_file(file_path: str, project: Project = Depends(dep_project)) -> F
@router.post( @router.post(
"/{project_id}/files/{file_path:path}", "/{project_id}/files/{file_path:path}",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Project.Modify"))]
) )
async def write_file(file_path: str, request: Request, project: Project = Depends(dep_project)) -> None: async def write_file(file_path: str, request: Request, project: Project = Depends(dep_project)) -> None:
""" """
@ -505,7 +501,7 @@ async def write_file(file_path: str, request: Request, project: Project = Depend
response_model=schemas.Node, response_model=schemas.Node,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
responses={404: {"model": schemas.ErrorMessage, "description": "Could not find project or template"}}, responses={404: {"model": schemas.ErrorMessage, "description": "Could not find project or template"}},
dependencies=[Depends(get_current_active_user)] dependencies=[Depends(has_privilege("Node.Allocate"))]
) )
async def create_node_from_template( async def create_node_from_template(
project_id: UUID, project_id: UUID,

View File

@ -123,57 +123,57 @@ async def delete_role(
raise ControllerError(f"Role '{role_id}' could not be deleted") raise ControllerError(f"Role '{role_id}' could not be deleted")
@router.get("/{role_id}/permissions", response_model=List[schemas.Permission]) @router.get("/{role_id}/privileges", response_model=List[schemas.Privilege])
async def get_role_permissions( async def get_role_privileges(
role_id: UUID, role_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]: ) -> List[schemas.Privilege]:
""" """
Get all role permissions. Get all role privileges.
""" """
return await rbac_repo.get_role_permissions(role_id) return await rbac_repo.get_role_privileges(role_id)
@router.put( @router.put(
"/{role_id}/permissions/{permission_id}", "/{role_id}/privileges/{privilege_id}",
status_code=status.HTTP_204_NO_CONTENT status_code=status.HTTP_204_NO_CONTENT
) )
async def add_permission_to_role( async def add_privilege_to_role(
role_id: UUID, role_id: UUID,
permission_id: UUID, privilege_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)) rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None: ) -> None:
""" """
Add a permission to a role. Add a privilege to a role.
""" """
permission = await rbac_repo.get_permission(permission_id) privilege = await rbac_repo.get_privilege(privilege_id)
if not permission: if not privilege:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found") raise ControllerNotFoundError(f"Privilege '{privilege_id}' not found")
role = await rbac_repo.add_permission_to_role(role_id, permission) role = await rbac_repo.add_privilege_to_role(role_id, privilege)
if not role: if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found") raise ControllerNotFoundError(f"Role '{role_id}' not found")
@router.delete( @router.delete(
"/{role_id}/permissions/{permission_id}", "/{role_id}/privileges/{privilege_id}",
status_code=status.HTTP_204_NO_CONTENT status_code=status.HTTP_204_NO_CONTENT
) )
async def remove_permission_from_role( async def remove_privilege_from_role(
role_id: UUID, role_id: UUID,
permission_id: UUID, privilege_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)), rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None: ) -> None:
""" """
Remove member from an user group. Remove privilege from a role.
""" """
permission = await rbac_repo.get_permission(permission_id) privilege = await rbac_repo.get_privilege(privilege_id)
if not permission: if not privilege:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found") raise ControllerNotFoundError(f"Privilege '{privilege_id}' not found")
role = await rbac_repo.remove_permission_from_role(role_id, permission) role = await rbac_repo.remove_privilege_from_role(role_id, privilege)
if not role: if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found") raise ControllerNotFoundError(f"Role '{role_id}' not found")

View File

@ -46,17 +46,13 @@ router = APIRouter(responses=responses)
@router.post("", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) @router.post("", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
async def create_template( async def create_template(
template_create: schemas.TemplateCreate, template_create: schemas.TemplateCreate,
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository))
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Template: ) -> schemas.Template:
""" """
Create a new template. Create a new template.
""" """
template = await TemplatesService(templates_repo).create_template(template_create) template = await TemplatesService(templates_repo).create_template(template_create)
template_id = template.get("template_id")
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*")
return template return template
@ -108,7 +104,7 @@ async def delete_template(
""" """
await TemplatesService(templates_repo).delete_template(template_id) await TemplatesService(templates_repo).delete_template(template_id)
await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}") #await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}")
if prune_images: if prune_images:
await images_repo.prune_images() await images_repo.prune_images()
@ -129,27 +125,24 @@ async def get_templates(
else: else:
user_templates = [] user_templates = []
for template in templates: for template in templates:
if template.get("builtin") is True: # if template.get("builtin") is True:
user_templates.append(template) # user_templates.append(template)
continue # continue
template_id = template.get("template_id") # template_id = template.get("template_id")
authorized = await rbac_repo.check_user_is_authorized( # authorized = await rbac_repo.check_user_is_authorized(
current_user.user_id, "GET", f"/templates/{template_id}") # current_user.user_id, "GET", f"/templates/{template_id}")
if authorized: # if authorized:
user_templates.append(template) user_templates.append(template)
return user_templates return user_templates
@router.post("/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) @router.post("/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
async def duplicate_template( async def duplicate_template(
template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository))
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Template: ) -> schemas.Template:
""" """
Duplicate a template. Duplicate a template.
""" """
template = await TemplatesService(templates_repo).duplicate_template(template_id) template = await TemplatesService(templates_repo).duplicate_template(template_id)
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*")
return template return template

View File

@ -155,7 +155,7 @@ async def get_user(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> schemas.User: ) -> schemas.User:
""" """
Get an user. Get a user.
""" """
user = await users_repo.get_user(user_id) user = await users_repo.get_user(user_id)
@ -171,7 +171,7 @@ async def update_user(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User: ) -> schemas.User:
""" """
Update an user. Update a user.
""" """
if user_update.username and await users_repo.get_user_by_username(user_update.username): if user_update.username and await users_repo.get_user_by_username(user_update.username):
@ -196,7 +196,7 @@ async def delete_user(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)), users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> None: ) -> None:
""" """
Delete an user. Delete a user.
""" """
user = await users_repo.get_user(user_id) user = await users_repo.get_user(user_id)
@ -225,65 +225,3 @@ async def get_user_memberships(
""" """
return await users_repo.get_user_memberships(user_id) return await users_repo.get_user_memberships(user_id)
@router.get(
"/{user_id}/permissions",
dependencies=[Depends(get_current_active_user)],
response_model=List[schemas.Permission]
)
async def get_user_permissions(
user_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]:
"""
Get user permissions.
"""
return await rbac_repo.get_user_permissions(user_id)
@router.put(
"/{user_id}/permissions/{permission_id}",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT
)
async def add_permission_to_user(
user_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Add a permission to an user.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
user = await rbac_repo.add_permission_to_user(user_id, permission)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.delete(
"/{user_id}/permissions/{permission_id}",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_permission_from_user(
user_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Remove permission from an user.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
user = await rbac_repo.remove_permission_from_user(user_id, permission)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")

View File

@ -213,8 +213,8 @@ class ApplianceManager:
except ValidationError as e: except ValidationError as e:
raise ControllerError(message=f"Could not validate template data: {e}") raise ControllerError(message=f"Could not validate template data: {e}")
template = await TemplatesService(templates_repo).create_template(template_create) template = await TemplatesService(templates_repo).create_template(template_create)
template_id = template.get("template_id") #template_id = template.get("template_id")
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*") #await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*")
log.info(f"Template '{template.get('name')}' has been created") log.info(f"Template '{template.get('name')}' has been created")
async def _appliance_to_template(self, appliance: Appliance, version: str = None) -> dict: async def _appliance_to_template(self, appliance: Appliance, version: str = None) -> dict:

View File

@ -16,11 +16,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from .base import Base from .base import Base
from .acl import ACL from .acl import ACE
from .resources import Resource
from .users import User, UserGroup from .users import User, UserGroup
from .roles import Role from .roles import Role
from .permissions import Permission from .privileges import Privilege
from .computes import Compute from .computes import Compute
from .images import Image from .images import Image
from .templates import ( from .templates import (

View File

@ -15,7 +15,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Column, Boolean, ForeignKey from sqlalchemy import Column, String, Boolean, ForeignKey, CheckConstraint
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import BaseTable, generate_uuid, GUID from .base import BaseTable, generate_uuid, GUID
@ -25,17 +25,22 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ACL(BaseTable): class ACE(BaseTable):
__tablename__ = "acl" __tablename__ = "acl"
acl_id = Column(GUID, primary_key=True, default=generate_uuid) ace_id = Column(GUID, primary_key=True, default=generate_uuid)
path = Column(String)
propagate = Column(Boolean, default=True)
allowed = Column(Boolean, default=True) allowed = Column(Boolean, default=True)
type: str = Column(String)
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE")) user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE"))
user = relationship("User", back_populates="acl_entries") user = relationship("User", back_populates="acl_entries")
group_id = Column(GUID, ForeignKey('user_groups.user_group_id', ondelete="CASCADE")) group_id = Column(GUID, ForeignKey('user_groups.user_group_id', ondelete="CASCADE"))
group = relationship("UserGroup", back_populates="acl_entries") group = relationship("UserGroup", back_populates="acl_entries")
resource_id = Column(GUID, ForeignKey('resources.resource_id', ondelete="CASCADE"))
resource = relationship("Resource", back_populates="acl_entries")
role_id = Column(GUID, ForeignKey('roles.role_id', ondelete="CASCADE")) role_id = Column(GUID, ForeignKey('roles.role_id', ondelete="CASCADE"))
role = relationship("Role", back_populates="acl_entries") role = relationship("Role", back_populates="acl_entries")
__table_args__ = (
CheckConstraint("(user_id IS NOT NULL AND type = 'user') OR (group_id IS NOT NULL AND type = 'group')"),
)

View File

@ -1,129 +0,0 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID, ListType
import logging
log = logging.getLogger(__name__)
permission_role_map = Table(
"permission_role_map",
Base.metadata,
Column("permission_id", GUID, ForeignKey("permissions.permission_id", ondelete="CASCADE")),
Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE"))
)
class Permission(BaseTable):
__tablename__ = "permissions"
permission_id = Column(GUID, primary_key=True, default=generate_uuid)
description = Column(String)
methods = Column(ListType)
path = Column(String)
action = Column(String)
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE"))
user = relationship("User", back_populates="permissions")
roles = relationship("Role", secondary=permission_role_map, back_populates="permissions")
@event.listens_for(Permission.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_permissions = [
{
"description": "Allow access to all endpoints",
"methods": ["GET", "POST", "PUT", "DELETE"],
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow to receive controller notifications",
"methods": ["GET"],
"path": "/notifications",
"action": "ALLOW"
},
{
"description": "Allow to create and list projects",
"methods": ["GET", "POST"],
"path": "/projects",
"action": "ALLOW"
},
{
"description": "Allow to create and list templates",
"methods": ["GET", "POST"],
"path": "/templates",
"action": "ALLOW"
},
{
"description": "Allow to list computes",
"methods": ["GET"],
"path": "/computes/*",
"action": "ALLOW"
},
{
"description": "Allow access to all symbol endpoints",
"methods": ["GET", "POST"],
"path": "/symbols/*",
"action": "ALLOW"
},
]
stmt = target.insert().values(default_permissions)
connection.execute(stmt)
connection.commit()
log.debug("The default permissions have been created in the database")
@event.listens_for(permission_role_map, 'after_create')
def add_permissions_to_role(target, connection, **kw):
from .roles import Role
roles_table = Role.__table__
stmt = roles_table.select().where(roles_table.c.name == "Administrator")
result = connection.execute(stmt)
role_id = result.first().role_id
permissions_table = Permission.__table__
stmt = permissions_table.select().where(permissions_table.c.path == "/")
result = connection.execute(stmt)
permission_id = result.first().permission_id
# add root path to the "Administrator" role
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
stmt = roles_table.select().where(roles_table.c.name == "User")
result = connection.execute(stmt)
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
connection.commit()

View File

@ -0,0 +1,258 @@
#!/usr/bin/env python
#
# Copyright (C) 2023 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID
import logging
log = logging.getLogger(__name__)
privilege_role_map = Table(
"privilege_role_map",
Base.metadata,
Column("privilege_id", GUID, ForeignKey("privileges.privilege_id", ondelete="CASCADE")),
Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE"))
)
class Privilege(BaseTable):
__tablename__ = "privileges"
privilege_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String)
description = Column(String)
roles = relationship("Role", secondary=privilege_role_map, back_populates="privileges")
@event.listens_for(Privilege.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_privileges = [
{
"description": "Create or delete a user",
"name": "User.Allocate"
},
{
"description": "View a user",
"name": "User.Audit"
},
{
"description": "Update a user",
"name": "User.Modify"
},
{
"description": "Create or delete a group",
"name": "Group.Allocate"
},
{
"description": "View a group",
"name": "Group.Audit"
},
{
"description": "Update a group",
"name": "Group.Modify"
},
{
"description": "Create or delete a template",
"name": "Template.Allocate"
},
{
"description": "View a template",
"name": "Template.Audit"
},
{
"description": "Update a template",
"name": "Template.Modify"
},
{
"description": "Create or delete a project",
"name": "Project.Allocate"
},
{
"description": "View a project",
"name": "Project.Audit"
},
{
"description": "Update a project",
"name": "Project.Modify"
},
{
"description": "Create or delete project snapshots",
"name": "Project.Snapshot"
},
{
"description": "Create or delete a node",
"name": "Node.Allocate"
},
{
"description": "View a node",
"name": "Node.Audit"
},
{
"description": "Update a node",
"name": "Node.Modify"
},
{
"description": "Console access to a node",
"name": "Node.Console"
},
{
"description": "Power management for a node",
"name": "Node.PowerMgmt"
},
{
"description": "Create or delete a link",
"name": "Link.Allocate"
},
{
"description": "View a link",
"name": "Link.Audit"
},
{
"description": "Update a link",
"name": "Link.Modify"
},
{
"description": "Capture packets on a link",
"name": "Link.Capture"
},
{
"description": "Create or delete a drawing",
"name": "Drawing.Allocate"
},
{
"description": "View a drawing",
"name": "Drawing.Audit"
},
{
"description": "Update a drawing",
"name": "Drawing.Modify"
},
{
"description": "Create or delete a symbol",
"name": "Symbol.Allocate"
},
{
"description": "View a symbol",
"name": "Symbol.Audit"
},
{
"description": "Create or delete an image",
"name": "Image.Allocate"
},
{
"description": "View an image",
"name": "Image.Audit"
},
{
"description": "Create or delete a compute",
"name": "Compute.Allocate"
},
{
"description": "View a compute",
"name": "Compute.Audit"
},
]
stmt = target.insert().values(default_privileges)
connection.execute(stmt)
connection.commit()
log.debug("The default privileges have been created in the database")
def add_privileges_to_role(target, connection, role, privileges):
from .roles import Role
roles_table = Role.__table__
privileges_table = Privilege.__table__
stmt = roles_table.select().where(roles_table.c.name == role)
result = connection.execute(stmt)
role_id = result.first().role_id
for privilege_name in privileges:
stmt = privileges_table.select().where(privileges_table.c.name == privilege_name)
result = connection.execute(stmt)
privilege_id = result.first().privilege_id
stmt = target.insert().values(privilege_id=privilege_id, role_id=role_id)
connection.execute(stmt)
@event.listens_for(privilege_role_map, 'after_create')
def add_privileges_to_default_roles(target, connection, **kw):
from .roles import Role
roles_table = Role.__table__
stmt = roles_table.select().where(roles_table.c.name == "Administrator")
result = connection.execute(stmt)
role_id = result.first().role_id
# add all privileges to the "Administrator" role
privileges_table = Privilege.__table__
stmt = privileges_table.select()
result = connection.execute(stmt)
for row in result:
privilege_id = row.privilege_id
stmt = target.insert().values(privilege_id=privilege_id, role_id=role_id)
connection.execute(stmt)
# add required privileges to the "User" role
user_privileges = (
"Project.Allocate",
"Project.Audit",
"Project.Modify",
"Project.Snapshot",
"Node.Allocate",
"Node.Audit",
"Node.Modify",
"Node.Console",
"Node.PowerMgmt",
"Link.Allocate",
"Link.Audit",
"Link.Modify",
"Link.Capture",
"Drawing.Allocate",
"Drawing.Audit",
"Drawing.Modify",
"Template.Audit",
"Symbol.Audit",
"Image.Audit",
"Compute.Audit"
)
add_privileges_to_role(target, connection, "User", user_privileges)
# add required privileges to the "Auditor" role
auditor_privileges = (
"Project.Audit",
"Node.Audit",
"Link.Audit",
"Drawing.Audit",
"Template.Audit",
"Symbol.Audit",
"Image.Audit",
"Compute.Audit"
)
add_privileges_to_role(target, connection, "Auditor", auditor_privileges)
connection.commit()
log.debug("Privileges have been added to the default roles in the database")

View File

@ -25,21 +25,18 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Resource(BaseTable): class ResourcePool(BaseTable):
__tablename__ = "resources" __tablename__ = "resource_pools"
resource_id = Column(GUID, primary_key=True, default=generate_uuid) resource_id = Column(GUID, primary_key=True)
name = Column(String, unique=True, index=True) resource_type = Column(String)
description = Column(String)
propagate = Column(Boolean, default=True) # # Create a self-referential relationship to represent a hierarchy of resources
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE")) # parent_id = Column(GUID, ForeignKey("resources.resource_id", ondelete="CASCADE"))
user = relationship("User", back_populates="resources") # children = relationship(
acl_entries = relationship("ACL") # "Resource",
parent_id = Column(GUID, ForeignKey("resources.resource_id", ondelete="CASCADE")) # remote_side=[resource_id],
children = relationship( # cascade="all, delete-orphan",
"Resource", # single_parent=True
remote_side=[resource_id], # )
cascade="all, delete-orphan",
single_parent=True
)

View File

@ -15,23 +15,16 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, Boolean, ForeignKey, event from sqlalchemy import Column, String, Boolean, event
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID from .base import BaseTable, generate_uuid, GUID
from .permissions import permission_role_map from .privileges import privilege_role_map
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
role_group_map = Table(
"role_group_map",
Base.metadata,
Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE")),
Column("user_group_id", GUID, ForeignKey("user_groups.user_group_id", ondelete="CASCADE"))
)
class Role(BaseTable): class Role(BaseTable):
@ -41,9 +34,8 @@ class Role(BaseTable):
name = Column(String, unique=True, index=True) name = Column(String, unique=True, index=True)
description = Column(String) description = Column(String)
is_builtin = Column(Boolean, default=False) is_builtin = Column(Boolean, default=False)
permissions = relationship("Permission", secondary=permission_role_map, back_populates="roles") privileges = relationship("Privilege", secondary=privilege_role_map, back_populates="roles")
groups = relationship("UserGroup", secondary=role_group_map, back_populates="roles") acl_entries = relationship("ACE")
acl_entries = relationship("ACL")
@event.listens_for(Role.__table__, 'after_create') @event.listens_for(Role.__table__, 'after_create')
@ -52,31 +44,11 @@ def create_default_roles(target, connection, **kw):
default_roles = [ default_roles = [
{"name": "Administrator", "description": "Administrator role", "is_builtin": True}, {"name": "Administrator", "description": "Administrator role", "is_builtin": True},
{"name": "User", "description": "User role", "is_builtin": True}, {"name": "User", "description": "User role", "is_builtin": True},
{"name": "Auditor", "description": "Role with read only access", "is_builtin": True},
{"name": "No Access", "description": "Role with no privileges (used to forbid access)", "is_builtin": True}
] ]
stmt = target.insert().values(default_roles) stmt = target.insert().values(default_roles)
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.debug("The default roles have been created in the database") log.debug("The default roles have been created in the database")
@event.listens_for(role_group_map, 'after_create')
def add_admin_to_group(target, connection, **kw):
from .users import UserGroup
user_groups_table = UserGroup.__table__
roles_table = Role.__table__
# Add roles to built-in user groups
groups_to_roles = {"Administrators": "Administrator", "Users": "User"}
for user_group, role in groups_to_roles.items():
stmt = user_groups_table.select().where(user_groups_table.c.name == user_group)
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
stmt = roles_table.select().where(roles_table.c.name == role)
result = connection.execute(stmt)
role_id = result.first().role_id
stmt = target.insert().values(role_id=role_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()

View File

@ -19,7 +19,6 @@ from sqlalchemy import Table, Boolean, Column, String, DateTime, ForeignKey, eve
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID from .base import Base, BaseTable, generate_uuid, GUID
from .roles import role_group_map
from gns3server.config import Config from gns3server.config import Config
from gns3server.services import auth_service from gns3server.services import auth_service
@ -49,9 +48,7 @@ class User(BaseTable):
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
is_superadmin = Column(Boolean, default=False) is_superadmin = Column(Boolean, default=False)
groups = relationship("UserGroup", secondary=user_group_map, back_populates="users") groups = relationship("UserGroup", secondary=user_group_map, back_populates="users")
resources = relationship("Resource") acl_entries = relationship("ACE")
permissions = relationship("Permission")
acl_entries = relationship("ACL")
@event.listens_for(User.__table__, 'after_create') @event.listens_for(User.__table__, 'after_create')
@ -80,8 +77,7 @@ class UserGroup(BaseTable):
name = Column(String, unique=True, index=True) name = Column(String, unique=True, index=True)
is_builtin = Column(Boolean, default=False) is_builtin = Column(Boolean, default=False)
users = relationship("User", secondary=user_group_map, back_populates="groups") users = relationship("User", secondary=user_group_map, back_populates="groups")
roles = relationship("Role", secondary=role_group_map, back_populates="groups") acl_entries = relationship("ACE")
acl_entries = relationship("ACL")
@event.listens_for(UserGroup.__table__, 'after_create') @event.listens_for(UserGroup.__table__, 'after_create')
@ -96,21 +92,3 @@ def create_default_user_groups(target, connection, **kw):
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.debug("The default user groups have been created in the database") log.debug("The default user groups have been created in the database")
# @event.listens_for(user_group_link, 'after_create')
# def add_admin_to_group(target, connection, **kw):
#
# user_groups_table = UserGroup.__table__
# stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators")
# result = connection.execute(stmt)
# user_group_id = result.first().user_group_id
#
# users_table = User.__table__
# stmt = users_table.select().where(users_table.c.is_superadmin.is_(True))
# result = connection.execute(stmt)
# user_id = result.first().user_id
#
# stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
# connection.execute(stmt)
# connection.commit()

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
# Copyright (C) 2020 GNS3 Technologies Inc. # Copyright (C) 2023 GNS3 Technologies Inc.
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by # it under the terms of the GNU General Public License as published by
@ -24,7 +24,6 @@ from sqlalchemy.orm import selectinload
from .base import BaseRepository from .base import BaseRepository
import gns3server.db.models as models import gns3server.db.models as models
from gns3server.schemas.controller.rbac import HTTPMethods, PermissionAction
from gns3server import schemas from gns3server import schemas
import logging import logging
@ -44,7 +43,7 @@ class RbacRepository(BaseRepository):
""" """
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.privileges)).\
where(models.Role.role_id == role_id) where(models.Role.role_id == role_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
@ -55,9 +54,8 @@ class RbacRepository(BaseRepository):
""" """
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.privileges)).\
where(models.Role.name == name) where(models.Role.name == name)
#query = select(models.Role).where(models.Role.name == name)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
@ -66,7 +64,7 @@ class RbacRepository(BaseRepository):
Get all roles. Get all roles.
""" """
query = select(models.Role).options(selectinload(models.Role.permissions)) query = select(models.Role).options(selectinload(models.Role.privileges))
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
@ -81,7 +79,6 @@ class RbacRepository(BaseRepository):
) )
self._db_session.add(db_role) self._db_session.add(db_role)
await self._db_session.commit() await self._db_session.commit()
#await self._db_session.refresh(db_role)
return await self.get_role(db_role.role_id) return await self.get_role(db_role.role_id)
async def update_role( async def update_role(
@ -115,286 +112,256 @@ class RbacRepository(BaseRepository):
await self._db_session.commit() await self._db_session.commit()
return result.rowcount > 0 return result.rowcount > 0
async def add_permission_to_role( async def add_privilege_to_role(
self, self,
role_id: UUID, role_id: UUID,
permission: models.Permission privilege: models.Privilege
) -> Union[None, models.Role]: ) -> Union[None, models.Role]:
""" """
Add a permission to a role. Add a privilege to a role.
""" """
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.privileges)).\
where(models.Role.role_id == role_id) where(models.Role.role_id == role_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
role_db = result.scalars().first() role_db = result.scalars().first()
if not role_db: if not role_db:
return None return None
role_db.permissions.append(permission) role_db.privileges.append(privilege)
await self._db_session.commit() await self._db_session.commit()
await self._db_session.refresh(role_db) await self._db_session.refresh(role_db)
return role_db return role_db
async def remove_permission_from_role( async def remove_privilege_from_role(
self, self,
role_id: UUID, role_id: UUID,
permission: models.Permission privilege: models.Privilege
) -> Union[None, models.Role]: ) -> Union[None, models.Role]:
""" """
Remove a permission from a role. Remove a privilege from a role.
""" """
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.privileges)).\
where(models.Role.role_id == role_id) where(models.Role.role_id == role_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
role_db = result.scalars().first() role_db = result.scalars().first()
if not role_db: if not role_db:
return None return None
role_db.permissions.remove(permission) role_db.privileges.remove(privilege)
await self._db_session.commit() await self._db_session.commit()
await self._db_session.refresh(role_db) await self._db_session.refresh(role_db)
return role_db return role_db
async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]: async def get_role_privileges(self, role_id: UUID) -> List[models.Privilege]:
""" """
Get all the role permissions. Get all the role privileges.
""" """
query = select(models.Permission).\ query = select(models.Privilege).\
join(models.Permission.roles).\ join(models.Privilege.roles).\
filter(models.Role.role_id == role_id) filter(models.Role.role_id == role_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]: async def get_privilege(self, privilege_id: UUID) -> Optional[models.Privilege]:
""" """
Get a permission by its ID. Get a privilege by its ID.
""" """
query = select(models.Permission).where(models.Permission.permission_id == permission_id) query = select(models.Privilege).where(models.Privilege.privilege_id == privilege_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_permission_by_path(self, path: str) -> Optional[models.Permission]: async def get_privilege_by_name(self, name: str) -> Optional[models.Privilege]:
""" """
Get a permission by its path. Get a privilege by its name.
""" """
query = select(models.Permission).where(models.Permission.path == path) query = select(models.Privilege).where(models.Privilege.name == name)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_permissions(self) -> List[models.Permission]: async def get_privileges(self) -> List[models.Privilege]:
""" """
Get all permissions. Get all privileges.
""" """
query = select(models.Permission).\ query = select(models.Privilege)
order_by(models.Permission.path.desc())
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def check_permission_exists(self, permission_create: schemas.PermissionCreate) -> bool: async def get_ace(self, ace_id: UUID) -> Optional[models.ACE]:
""" """
Check if a permission exists. Get an ACE by its ID.
""" """
query = select(models.Permission).\ query = select(models.ACE).where(models.ACE.ace_id == ace_id)
where(models.Permission.methods == permission_create.methods, result = await self._db_session.execute(query)
models.Permission.path == permission_create.path, return result.scalars().first()
models.Permission.action == permission_create.action)
async def get_ace_by_path(self, path: str) -> Optional[models.ACE]:
"""
Get an ACE by its path.
"""
query = select(models.ACE).where(models.ACE.path == path)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_aces(self) -> List[models.ACE]:
"""
Get all ACEs.
"""
query = select(models.ACE)
result = await self._db_session.execute(query)
return result.scalars().all()
async def check_ace_exists(self, path: str) -> bool:
"""
Check if an ACE exists.
"""
query = select(models.ACE).\
where(models.ACE.path == path)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() is not None return result.scalars().first() is not None
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission: async def create_ace(self, ace_create: schemas.ACECreate) -> models.ACE:
""" """
Create a new permission. Create a new ACE
""" """
db_permission = models.Permission( create_values = ace_create.model_dump(exclude_unset=True)
description=permission_create.description, db_ace = models.ACE(**create_values)
methods=permission_create.methods, self._db_session.add(db_ace)
path=permission_create.path,
action=permission_create.action,
)
self._db_session.add(db_permission)
await self._db_session.commit() await self._db_session.commit()
await self._db_session.refresh(db_permission) await self._db_session.refresh(db_ace)
return db_permission return db_ace
async def update_permission( async def update_ace(
self, self,
permission_id: UUID, ace_id: UUID,
permission_update: schemas.PermissionUpdate ace_update: schemas.ACEUpdate
) -> Optional[models.Permission]: ) -> Optional[models.ACE]:
""" """
Update a permission. Update an ACE
""" """
update_values = permission_update.model_dump(exclude_unset=True) update_values = ace_update.model_dump(exclude_unset=True)
query = update(models.Permission).\ query = update(models.ACE).\
where(models.Permission.permission_id == permission_id).\ where(models.ACE.ace_id == ace_id).\
values(update_values) values(update_values)
await self._db_session.execute(query) await self._db_session.execute(query)
await self._db_session.commit() await self._db_session.commit()
permission_db = await self.get_permission(permission_id) ace_db = await self.get_ace(ace_id)
if permission_db: if ace_db:
await self._db_session.refresh(permission_db) # force refresh of updated_at value await self._db_session.refresh(ace_db) # force refresh of updated_at value
return permission_db return ace_db
async def delete_permission(self, permission_id: UUID) -> bool: async def delete_ace(self, ace_id: UUID) -> bool:
""" """
Delete a permission. Delete an ACE
""" """
query = delete(models.Permission).where(models.Permission.permission_id == permission_id) query = delete(models.ACE).where(models.ACE.ace_id == ace_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
await self._db_session.commit() await self._db_session.commit()
return result.rowcount > 0 return result.rowcount > 0
async def prune_permissions(self) -> int: # async def prune_permissions(self) -> int:
# """
# Prune orphaned permissions.
# """
#
# query = select(models.Permission).\
# filter((~models.Permission.roles.any()) & (models.Permission.user_id == null()))
# result = await self._db_session.execute(query)
# permissions = result.scalars().all()
# permissions_deleted = 0
# for permission in permissions:
# if await self.delete_permission(permission.permission_id):
# permissions_deleted += 1
# log.info(f"{permissions_deleted} orphaned permissions have been deleted")
# return permissions_deleted
#
# def _match_permission(
# self,
# permissions: List[models.Permission],
# method: str,
# path: str
# ) -> Union[None, models.Permission]:
# """
# Match the methods and path with a permission.
# """
#
# for permission in permissions:
# log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
# if method not in permission.methods:
# continue
# if permission.path.endswith("/*") and path.startswith(permission.path[:-2]):
# return permission
# elif permission.path == path:
# return permission
async def delete_all_ace_starting_with_path(self, path: str) -> None:
""" """
Prune orphaned permissions. Delete all ACEs starting with path.
""" """
query = select(models.Permission).\ query = delete(models.ACE).\
filter((~models.Permission.roles.any()) & (models.Permission.user_id == null())) where(models.ACE.path.startswith(path)).\
result = await self._db_session.execute(query)
permissions = result.scalars().all()
permissions_deleted = 0
for permission in permissions:
if await self.delete_permission(permission.permission_id):
permissions_deleted += 1
log.info(f"{permissions_deleted} orphaned permissions have been deleted")
return permissions_deleted
def _match_permission(
self,
permissions: List[models.Permission],
method: str,
path: str
) -> Union[None, models.Permission]:
"""
Match the methods and path with a permission.
"""
for permission in permissions:
log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
if method not in permission.methods:
continue
if permission.path.endswith("/*") and path.startswith(permission.path[:-2]):
return permission
elif permission.path == path:
return permission
async def get_user_permissions(self, user_id: UUID):
"""
Get all permissions from an user.
"""
query = select(models.Permission).\
join(models.User.permissions).\
filter(models.User.user_id == user_id).\
order_by(models.Permission.path.desc())
result = await self._db_session.execute(query)
return result.scalars().all()
async def add_permission_to_user(
self,
user_id: UUID,
permission: models.Permission
) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def remove_permission_from_user(
self,
user_id: UUID,
permission: models.Permission
) -> Union[None, models.User]:
"""
Remove a permission from a role.
"""
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.remove(permission)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def add_permission_to_user_with_path(self, user_id: UUID, path: str) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
# Create a new permission with full rights on path
new_permission = schemas.PermissionCreate(
description=f"Allow access to {path}",
methods=[HTTPMethods.get, HTTPMethods.head, HTTPMethods.post, HTTPMethods.put, HTTPMethods.delete],
path=path,
action=PermissionAction.allow
)
permission_db = await self.create_permission(new_permission)
# Add the permission to the user
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission_db)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def delete_all_permissions_with_path(self, path: str) -> None:
"""
Delete all permissions with path.
"""
query = delete(models.Permission).\
where(models.Permission.path.startswith(path)).\
execution_options(synchronize_session=False) execution_options(synchronize_session=False)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
log.debug(f"{result.rowcount} permission(s) have been deleted") log.debug(f"{result.rowcount} ACE(s) have been deleted")
async def check_user_is_authorized(self, user_id: UUID, method: str, path: str) -> bool: async def check_user_has_privilege(self, user_id: UUID, path: str, privilege_name: str) -> bool:
# query = select(models.Privilege.name).\
# join(models.Privilege.roles).\
# join(models.Role.acl_entries).\
# join(models.ACE.user).\
# filter(models.Privilege.name == privilege). \
# filter(models.User.user_id == user_id).\
# filter(models.ACE.path == path).\
# distinct()
#query = select(models.ACE.path)
#result = await self._db_session.execute(query)
#res = result.scalars().all()
#print("ACL TABLE ==>", res)
#for ace in res:
# print(ace)
query = select(models.Privilege.name, models.ACE.path, models.ACE.propagate).\
join(models.Privilege.roles).\
join(models.Role.acl_entries).\
join(models.ACE.user).\
filter(models.User.user_id == user_id).\
filter(models.Privilege.name == privilege_name).\
filter(models.ACE.path == path).\
order_by(models.ACE.path.desc())
result = await self._db_session.execute(query)
privileges = result.all()
#print(privileges)
for privilege, privilege_path, propagate in privileges:
if privilege_path == path:
return True
return False
async def check_user_is_authorized(self, user_id: UUID, path: str) -> bool:
""" """
Check if an user is authorized to access a resource. Check if a user is authorized to access a resource.
""" """
return True
query = select(models.Permission).\ query = select(models.Permission).\
join(models.Permission.roles).\ join(models.Permission.roles).\
join(models.Role.groups).\ join(models.Role.groups).\

View File

@ -287,60 +287,3 @@ class UsersRepository(BaseRepository):
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def add_role_to_user_group(
self,
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Add a role to a user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.roles.append(role)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def remove_role_from_user_group(
self,
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Remove a role from a user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.roles.remove(role)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]:
"""
Get all roles from a user group.
"""
query = select(models.Role). \
options(selectinload(models.Role.permissions)). \
join(models.UserGroup.roles). \
filter(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
return result.scalars().all()

View File

@ -30,7 +30,7 @@ from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission from .controller.rbac import RoleCreate, RoleUpdate, Role, Privilege, ACECreate, ACEUpdate, ACE
from .controller.tokens import Token from .controller.tokens import Token
from .controller.snapshots import SnapshotCreate, Snapshot from .controller.snapshots import SnapshotCreate, Snapshot
from .controller.iou_license import IOULicense from .controller.iou_license import IOULicense

View File

@ -15,71 +15,68 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List from typing import Optional, List
from pydantic import field_validator, ConfigDict, BaseModel from pydantic import ConfigDict, BaseModel, Field
from uuid import UUID from uuid import UUID
from enum import Enum from enum import Enum
from .base import DateTimeModelMixin from .base import DateTimeModelMixin
class HTTPMethods(str, Enum): class PrivilegeBase(BaseModel):
""" """
HTTP method type. Common privilege properties.
""" """
get = "GET" name: str
head = "HEAD"
post = "POST"
patch = "PATCH"
put = "PUT"
delete = "DELETE"
class PermissionAction(str, Enum):
"""
Action to perform when permission is matched.
"""
allow = "ALLOW"
deny = "DENY"
class PermissionBase(BaseModel):
"""
Common permission properties.
"""
methods: List[HTTPMethods]
path: str
action: PermissionAction
description: Optional[str] = None description: Optional[str] = None
class Privilege(DateTimeModelMixin, PrivilegeBase):
privilege_id: UUID
model_config = ConfigDict(from_attributes=True)
class ACEType(str, Enum):
user = "user"
group = "group"
class ACEBase(BaseModel):
"""
Common ACE properties.
"""
path: str
propagate: Optional[bool] = True
allowed: Optional[bool] = True
type: ACEType = Field(..., description="Type of the ACE")
user_id: Optional[UUID] = None
group_id: Optional[UUID] = None
role_id: UUID
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
@field_validator("action", mode="before")
@classmethod
def action_uppercase(cls, v):
return v.upper()
class ACECreate(ACEBase):
class PermissionCreate(PermissionBase):
""" """
Properties to create a permission. Properties to create an ACE.
""" """
pass pass
class PermissionUpdate(PermissionBase): class ACEUpdate(ACEBase):
""" """
Properties to update a role. Properties to update an ACE.
""" """
pass pass
class Permission(DateTimeModelMixin, PermissionBase): class ACE(DateTimeModelMixin, ACEBase):
permission_id: UUID ace_id: UUID
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -112,5 +109,5 @@ class Role(DateTimeModelMixin, RoleBase):
role_id: UUID role_id: UUID
is_builtin: bool is_builtin: bool
permissions: List[Permission] privileges: List[Privilege]
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View File

@ -52,7 +52,7 @@ class UserUpdate(UserBase):
class LoggedInUserUpdate(BaseModel): class LoggedInUserUpdate(BaseModel):
""" """
Properties to update a logged in user. Properties to update a logged-in user.
""" """
password: Optional[SecretStr] = Field(None, min_length=6, max_length=100) password: Optional[SecretStr] = Field(None, min_length=6, max_length=100)

View File

@ -0,0 +1,214 @@
#!/usr/bin/env python
#
# Copyright (C) 2023 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
import pytest_asyncio
import uuid
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.controller import Controller
from gns3server.controller.project import Project
from gns3server.schemas.controller.users import User
from gns3server.schemas.controller.rbac import ACECreate
pytestmark = pytest.mark.asyncio
class TestACLRoutes:
# @pytest_asyncio.fixture
# async def project(
# self,
# app: FastAPI,
# authorized_client: AsyncClient,
# test_user: User,
# db_session: AsyncSession,
# controller: Controller
# ) -> Project:
#
# # add an ACE to allow user to create a project
# user_id = test_user.user_id
# rbac_repo = RbacRepository(db_session)
# role_in_db = await rbac_repo.get_role_by_name("User")
# role_id = role_in_db.role_id
# ace = ACECreate(
# path="/projects",
# type="user",
# user_id=user_id,
# role_id=role_id
# )
# await rbac_repo.create_ace(ace)
# project_uuid = str(uuid.uuid4())
# params = {"name": "test", "project_id": project_uuid}
# response = await authorized_client.post(app.url_path_for("create_project"), json=params)
# assert response.status_code == status.HTTP_201_CREATED
# return controller.get_project(project_uuid)
#@pytest_asyncio.fixture
# async def project(
# self,
# app: FastAPI,
# client: AsyncClient,
# controller: Controller
# ) -> Project:
#
# project_uuid = str(uuid.uuid4())
# params = {"name": "test", "project_id": project_uuid}
# response = await client.post(app.url_path_for("create_project"), json=params)
# assert response.status_code == status.HTTP_201_CREATED
# return controller.get_project(project_uuid)
@pytest_asyncio.fixture
async def group_id(self, db_session: AsyncSession) -> str:
users_repo = UsersRepository(db_session)
group_in_db = await users_repo.get_user_group_by_name("Users")
group_id = str(group_in_db.user_group_id)
return group_id
@pytest_asyncio.fixture
async def role_id(self, db_session: AsyncSession) -> str:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User")
role_id = str(role_in_db.role_id)
return role_id
async def test_create_ace(
self,
app: FastAPI,
authorized_client: AsyncClient,
db_session: AsyncSession,
test_user: User,
role_id: str
) -> None:
# add an ACE on /projects to allow user to create a project
path = f"/projects"
new_ace = {
"path": path,
"type": "user",
"user_id": str(test_user.user_id),
"role_id": role_id
}
response = await authorized_client.post(app.url_path_for("create_ace"), json=new_ace)
assert response.status_code == status.HTTP_201_CREATED
rbac_repo = RbacRepository(db_session)
assert await rbac_repo.check_user_has_privilege(test_user.user_id, path, "Project.Allocate") is True
response = await authorized_client.post(app.url_path_for("create_project"), json={"name": "test"})
assert response.status_code == status.HTTP_201_CREATED
async def test_create_ace_not_existing_endpoint(
self,
app: FastAPI,
client: AsyncClient,
group_id: str,
role_id: str
) -> None:
new_ace = {
"path": "/projects/invalid",
"type": "group",
"group_id": group_id,
"role_id": role_id
}
response = await client.post(app.url_path_for("create_ace"), json=new_ace)
assert response.status_code == status.HTTP_400_BAD_REQUEST
# async def test_create_ace_not_existing_resource(
# self,
# app: FastAPI,
# client: AsyncClient,
# group_id: str,
# role_id: str
# ) -> None:
#
# new_ace = {
# "path": f"/projects/{str(uuid.uuid4())}",
# "group_id": group_id,
# "role_id": role_id
# }
# response = await client.post(app.url_path_for("create_ace"), json=new_ace)
# assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_get_ace(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
ace_in_db = await rbac_repo.get_ace_by_path(f"/projects")
response = await client.get(app.url_path_for("get_ace", ace_id=ace_in_db.ace_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["ace_id"] == str(ace_in_db.ace_id)
async def test_list_aces(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_aces"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
async def test_update_ace(
self, app: FastAPI,
client: AsyncClient,
db_session: AsyncSession,
test_user: User,
role_id: str
) -> None:
rbac_repo = RbacRepository(db_session)
ace_in_db = await rbac_repo.get_ace_by_path(f"/projects")
update_ace = {
"path": f"/appliances",
"type": "user",
"user_id": str(test_user.user_id),
"role_id": role_id
}
response = await client.put(
app.url_path_for("update_ace", ace_id=ace_in_db.ace_id),
json=update_ace
)
assert response.status_code == status.HTTP_200_OK
updated_ace_in_db = await rbac_repo.get_ace(ace_in_db.ace_id)
assert updated_ace_in_db.path == f"/appliances"
async def test_delete_ace(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession,
) -> None:
rbac_repo = RbacRepository(db_session)
ace_in_db = await rbac_repo.get_ace_by_path(f"/appliances")
response = await client.delete(app.url_path_for("delete_ace", ace_id=ace_in_db.ace_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
# async def test_prune_permissions(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
#
# response = await client.post(app.url_path_for("prune_permissions"))
# assert response.status_code == status.HTTP_204_NO_CONTENT
#
# rbac_repo = RbacRepository(db_session)
# permissions_in_db = await rbac_repo.get_permissions()
# assert len(permissions_in_db) == 10 # 6 default permissions + 4 custom permissions

View File

@ -16,17 +16,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest import pytest
import pytest_asyncio
from fastapi import FastAPI, status from fastapi import FastAPI, status
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.users import User from gns3server.schemas.controller.users import User
from gns3server.schemas.controller.rbac import Role
from gns3server import schemas
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -170,84 +166,3 @@ class TestGroupMembersRoutes:
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
members = await user_repo.get_user_group_members(group_in_db.user_group_id) members = await user_repo.get_user_group_members(group_in_db.user_group_id)
assert len(members) == 0 assert len(members) == 0
@pytest_asyncio.fixture
async def test_role(db_session: AsyncSession) -> Role:
new_role = schemas.RoleCreate(
name="TestRole",
description="This is my test role"
)
rbac_repo = RbacRepository(db_session)
existing_role = await rbac_repo.get_role_by_name(new_role.name)
if existing_role:
return existing_role
return await rbac_repo.create_role(new_role)
class TestGroupRolesRoutes:
async def test_add_role_to_group(
self,
app: FastAPI,
client: AsyncClient,
test_role: Role,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.put(
app.url_path_for(
"add_role_to_group",
user_group_id=group_in_db.user_group_id,
role_id=str(test_role.role_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 2 # 1 default role + 1 custom role
for role in roles:
if not role.is_builtin:
assert role.name == test_role.name
async def test_get_user_group_roles(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.get(
app.url_path_for(
"get_user_group_roles",
user_group_id=group_in_db.user_group_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 2 # 1 default role + 1 custom role
async def test_remove_role_from_group(
self,
app: FastAPI,
client: AsyncClient,
test_role: Role,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.delete(
app.url_path_for(
"remove_role_from_group",
user_group_id=group_in_db.user_group_id,
role_id=test_role.role_id
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 1 # 1 default role
assert roles[0].name != test_role.name

View File

@ -1,136 +0,0 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
import pytest_asyncio
import uuid
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.controller import Controller
from gns3server.controller.project import Project
pytestmark = pytest.mark.asyncio
class TestPermissionRoutes:
@pytest_asyncio.fixture
async def project(self, app: FastAPI, client: AsyncClient, controller: Controller) -> Project:
project_uuid = str(uuid.uuid4())
params = {"name": "test", "project_id": project_uuid}
await client.post(app.url_path_for("create_project"), json=params)
return controller.get_project(project_uuid)
async def test_create_permission(self, app: FastAPI, client: AsyncClient, project: Project) -> None:
new_permission = {
"methods": ["GET"],
"path": f"/projects/{project.id}",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
async def test_create_wildcard_permission(self, app: FastAPI, client: AsyncClient, project: Project) -> None:
new_permission = {
"methods": ["POST"],
"path": f"/projects/{project.id}/*",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
async def test_create_permission_not_existing_endpoint(self, app: FastAPI, client: AsyncClient) -> None:
new_permission = {
"methods": ["GET"],
"path": "/projects/invalid",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_400_BAD_REQUEST
async def test_create_permission_not_existing_object(self, app: FastAPI, client: AsyncClient) -> None:
new_permission = {
"methods": ["GET"],
"path": f"/projects/{str(uuid.uuid4())}/*",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_get_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession, project: Project) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path(f"/projects/{project.id}/*")
response = await client.get(app.url_path_for("get_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["permission_id"] == str(permission_in_db.permission_id)
async def test_list_permissions(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_permissions"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 11 # 6 default permissions + 5 custom permissions
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession, project: Project) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path(f"/projects/{project.id}/*")
update_permission = {
"methods": ["GET"],
"path": f"/projects/{project.id}/*",
"action": "ALLOW"
}
response = await client.put(
app.url_path_for("update_permission", permission_id=permission_in_db.permission_id),
json=update_permission
)
assert response.status_code == status.HTTP_200_OK
updated_permission_in_db = await rbac_repo.get_permission(permission_in_db.permission_id)
assert updated_permission_in_db.path == f"/projects/{project.id}/*"
async def test_delete_permission(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession,
project: Project,
) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path(f"/projects/{project.id}/*")
response = await client.delete(app.url_path_for("delete_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_prune_permissions(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
response = await client.post(app.url_path_for("prune_permissions"))
assert response.status_code == status.HTTP_204_NO_CONTENT
rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_permissions()
assert len(permissions_in_db) == 10 # 6 default permissions + 4 custom permissions

View File

@ -16,15 +16,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest import pytest
import pytest_asyncio
from fastapi import FastAPI, status from fastapi import FastAPI, status
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction
from gns3server import schemas
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -49,7 +46,7 @@ class TestRolesRoutes:
response = await client.get(app.url_path_for("get_roles")) response = await client.get(app.url_path_for("get_roles"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 3 # 2 default roles + role1 assert len(response.json()) == 5 # 4 default roles + role1
async def test_update_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: async def test_update_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
@ -106,46 +103,31 @@ class TestRolesRoutes:
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest_asyncio.fixture class TestRolesPrivilegesRoutes:
async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate( async def test_add_privilege_to_role(
methods=[HTTPMethods.get],
path="/statistics",
action=PermissionAction.allow
)
rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/statistics")
if existing_permission:
return existing_permission
return await rbac_repo.create_permission(new_permission)
class TestRolesPermissionsRoutes:
async def test_add_permission_to_role(
self, self,
app: FastAPI, app: FastAPI,
client: AsyncClient, client: AsyncClient,
test_permission: Permission,
db_session: AsyncSession db_session: AsyncSession
) -> None: ) -> None:
rbac_repo = RbacRepository(db_session) rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User") role_in_db = await rbac_repo.get_role_by_name("User")
privilege = await rbac_repo.get_privilege_by_name("Template.Allocate")
response = await client.put( response = await client.put(
app.url_path_for( app.url_path_for(
"add_permission_to_role", "add_privilege_to_role",
role_id=role_in_db.role_id, role_id=role_in_db.role_id,
permission_id=str(test_permission.permission_id) privilege_id=str(privilege.privilege_id)
) )
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) privileges = await rbac_repo.get_role_privileges(role_in_db.role_id)
assert len(permissions) == 6 # 5 default permissions + 1 custom permission assert len(privileges) == 21 # 20 default privileges + 1 custom privilege
async def test_get_role_permissions( async def test_get_role_privileges(
self, self,
app: FastAPI, app: FastAPI,
client: AsyncClient, client: AsyncClient,
@ -157,30 +139,30 @@ class TestRolesPermissionsRoutes:
response = await client.get( response = await client.get(
app.url_path_for( app.url_path_for(
"get_role_permissions", "get_role_privileges",
role_id=role_in_db.role_id) role_id=role_in_db.role_id)
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 6 # 5 default permissions + 1 custom permission assert len(response.json()) == 21 # 20 default privileges + 1 custom privilege
async def test_remove_role_from_group( async def test_remove_privilege_from_role(
self, self,
app: FastAPI, app: FastAPI,
client: AsyncClient, client: AsyncClient,
test_permission: Permission,
db_session: AsyncSession db_session: AsyncSession
) -> None: ) -> None:
rbac_repo = RbacRepository(db_session) rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User") role_in_db = await rbac_repo.get_role_by_name("User")
privilege = await rbac_repo.get_privilege_by_name("Template.Allocate")
response = await client.delete( response = await client.delete(
app.url_path_for( app.url_path_for(
"remove_permission_from_role", "remove_privilege_from_role",
role_id=role_in_db.role_id, role_id=role_in_db.role_id,
permission_id=str(test_permission.permission_id) privilege_id=str(privilege.privilege_id)
), ),
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) privileges = await rbac_repo.get_role_privileges(role_in_db.role_id)
assert len(permissions) == 5 # 5 default permissions assert len(privileges) == 20 # 20 default privileges

View File

@ -16,7 +16,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest import pytest
import pytest_asyncio
from typing import Optional from typing import Optional
from fastapi import FastAPI, HTTPException, status from fastapi import FastAPI, HTTPException, status
@ -26,12 +25,9 @@ from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction
from gns3server.services import auth_service from gns3server.services import auth_service
from gns3server.config import Config from gns3server.config import Config
from gns3server.schemas.controller.users import User from gns3server.schemas.controller.users import User
from gns3server import schemas
import gns3server.db.models as models import gns3server.db.models as models
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -352,7 +348,7 @@ class TestUserMe:
assert user.email == test_user.email assert user.email == test_user.email
assert user.user_id == test_user.user_id assert user.user_id == test_user.user_id
# logged in users can only change their email, full name and password # logged-in users can only change their email, full name and password
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attr, value, status_code", "attr, value, status_code",
( (
@ -426,92 +422,3 @@ class TestSuperAdmin:
response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
# async def test_super_admin_belongs_to_admin_group(
# self,
# app: FastAPI,
# client: AsyncClient,
# db_session: AsyncSession
# ) -> None:
#
# user_repo = UsersRepository(db_session)
# admin_in_db = await user_repo.get_user_by_username("admin")
# response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
# assert response.status_code == status.HTTP_200_OK
# assert len(response.json()) == 1
@pytest_asyncio.fixture
async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get],
path="/statistics",
action=PermissionAction.allow
)
rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/statistics")
if existing_permission:
return existing_permission
return await rbac_repo.create_permission(new_permission)
class TestUserPermissionsRoutes:
async def test_add_permission_to_user(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
test_permission: Permission,
db_session: AsyncSession
) -> None:
response = await client.put(
app.url_path_for(
"add_permission_to_user",
user_id=str(test_user.user_id),
permission_id=str(test_permission.permission_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
rbac_repo = RbacRepository(db_session)
permissions = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions) == 1
assert permissions[0].permission_id == test_permission.permission_id
async def test_get_user_permissions(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
response = await client.get(
app.url_path_for(
"get_user_permissions",
user_id=str(test_user.user_id))
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
async def test_remove_permission_from_user(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
test_permission: Permission,
db_session: AsyncSession
) -> None:
response = await client.delete(
app.url_path_for(
"remove_permission_from_user",
user_id=str(test_user.user_id),
permission_id=str(test_permission.permission_id)
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
rbac_repo = RbacRepository(db_session)
permissions = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions) == 0

View File

@ -115,7 +115,7 @@ async def test_user(db_session: AsyncSession) -> User:
return existing_user return existing_user
user = await user_repo.create_user(new_user) user = await user_repo.create_user(new_user)
# add new user to "Users group # add new user to the "Users" group
group = await user_repo.get_user_group_by_name("Users") group = await user_repo.get_user_group_by_name("Users")
await user_repo.add_member_to_user_group(group.user_group_id, user) await user_repo.add_member_to_user_group(group.user_group_id, user)
return user return user

View File

@ -27,177 +27,177 @@ from gns3server.db.models import User
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
class TestPermissions: # class TestPermissions:
#
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"method, path, result", # "method, path, result",
( # (
("GET", "/users", False), # ("GET", "/users", False),
("GET", "/projects", True), # ("GET", "/projects", True),
("GET", "/projects/e451ad73-2519-4f83-87fe-a8e821792d44", False), # ("GET", "/projects/e451ad73-2519-4f83-87fe-a8e821792d44", False),
("POST", "/projects", True), # ("POST", "/projects", True),
("GET", "/templates", True), # ("GET", "/templates", True),
("GET", "/templates/62e92cf1-244a-4486-8dae-b95439b54da9", False), # ("GET", "/templates/62e92cf1-244a-4486-8dae-b95439b54da9", False),
("POST", "/templates", True), # ("POST", "/templates", True),
("GET", "/computes", True), # ("GET", "/computes", True),
("GET", "/computes/local", True), # ("GET", "/computes/local", True),
("GET", "/symbols", True), # ("GET", "/symbols", True),
("GET", "/symbols/default_symbols", True), # ("GET", "/symbols/default_symbols", True),
), # ),
) # )
async def test_default_permissions_user_group( # async def test_default_permissions_user_group(
self, # self,
app: FastAPI, # app: FastAPI,
authorized_client: AsyncClient, # authorized_client: AsyncClient,
test_user: User, # test_user: User,
db_session: AsyncSession, # db_session: AsyncSession,
method: str, # method: str,
path: str, # path: str,
result: bool # result: bool
) -> None: # ) -> None:
#
rbac_repo = RbacRepository(db_session) # rbac_repo = RbacRepository(db_session)
authorized = await rbac_repo.check_user_is_authorized(test_user.user_id, method, path) # authorized = await rbac_repo.check_user_is_authorized(test_user.user_id, method, path)
assert authorized == result # assert authorized == result
#
#
class TestProjectsWithRbac: # class TestProjectsWithRbac:
#
async def test_admin_create_project(self, app: FastAPI, client: AsyncClient): # async def test_admin_create_project(self, app: FastAPI, client: AsyncClient):
#
params = {"name": "Admin project"} # params = {"name": "Admin project"}
response = await client.post(app.url_path_for("create_project"), json=params) # response = await client.post(app.url_path_for("create_project"), json=params)
assert response.status_code == status.HTTP_201_CREATED # assert response.status_code == status.HTTP_201_CREATED
#
async def test_user_only_access_own_projects( # async def test_user_only_access_own_projects(
self, # self,
app: FastAPI, # app: FastAPI,
authorized_client: AsyncClient, # authorized_client: AsyncClient,
test_user: User, # test_user: User,
db_session: AsyncSession # db_session: AsyncSession
) -> None: # ) -> None:
#
params = {"name": "User project"} # params = {"name": "User project"}
response = await authorized_client.post(app.url_path_for("create_project"), json=params) # response = await authorized_client.post(app.url_path_for("create_project"), json=params)
assert response.status_code == status.HTTP_201_CREATED # assert response.status_code == status.HTTP_201_CREATED
project_id = response.json()["project_id"] # project_id = response.json()["project_id"]
#
rbac_repo = RbacRepository(db_session) # rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id) # permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions_in_db) == 1 # assert len(permissions_in_db) == 1
assert permissions_in_db[0].path == f"/projects/{project_id}/*" # assert permissions_in_db[0].path == f"/projects/{project_id}/*"
#
response = await authorized_client.get(app.url_path_for("get_projects")) # response = await authorized_client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
projects = response.json() # projects = response.json()
assert len(projects) == 1 # assert len(projects) == 1
#
async def test_admin_access_all_projects(self, app: FastAPI, client: AsyncClient): # async def test_admin_access_all_projects(self, app: FastAPI, client: AsyncClient):
#
response = await client.get(app.url_path_for("get_projects")) # response = await client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
projects = response.json() # projects = response.json()
assert len(projects) == 2 # assert len(projects) == 2
#
async def test_admin_user_give_permission_on_project( # async def test_admin_user_give_permission_on_project(
self, # self,
app: FastAPI, # app: FastAPI,
client: AsyncClient, # client: AsyncClient,
test_user: User # test_user: User
): # ):
#
response = await client.get(app.url_path_for("get_projects")) # response = await client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
projects = response.json() # projects = response.json()
project_id = None # project_id = None
for project in projects: # for project in projects:
if project["name"] == "Admin project": # if project["name"] == "Admin project":
project_id = project["project_id"] # project_id = project["project_id"]
break # break
#
new_permission = { # new_permission = {
"methods": ["GET"], # "methods": ["GET"],
"path": f"/projects/{project_id}", # "path": f"/projects/{project_id}",
"action": "ALLOW" # "action": "ALLOW"
} # }
response = await client.post(app.url_path_for("create_permission"), json=new_permission) # response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED # assert response.status_code == status.HTTP_201_CREATED
permission_id = response.json()["permission_id"] # permission_id = response.json()["permission_id"]
#
response = await client.put( # response = await client.put(
app.url_path_for( # app.url_path_for(
"add_permission_to_user", # "add_permission_to_user",
user_id=test_user.user_id, # user_id=test_user.user_id,
permission_id=permission_id # permission_id=permission_id
) # )
) # )
assert response.status_code == status.HTTP_204_NO_CONTENT # assert response.status_code == status.HTTP_204_NO_CONTENT
#
async def test_user_access_admin_project( # async def test_user_access_admin_project(
self, # self,
app: FastAPI, # app: FastAPI,
authorized_client: AsyncClient, # authorized_client: AsyncClient,
test_user: User, # test_user: User,
db_session: AsyncSession # db_session: AsyncSession
) -> None: # ) -> None:
#
response = await authorized_client.get(app.url_path_for("get_projects")) # response = await authorized_client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
projects = response.json() # projects = response.json()
assert len(projects) == 2 # assert len(projects) == 2
#
#
class TestTemplatesWithRbac: # class TestTemplatesWithRbac:
#
async def test_admin_create_template(self, app: FastAPI, client: AsyncClient): # async def test_admin_create_template(self, app: FastAPI, client: AsyncClient):
#
new_template = {"base_script_file": "vpcs_base_config.txt", # new_template = {"base_script_file": "vpcs_base_config.txt",
"category": "guest", # "category": "guest",
"console_auto_start": False, # "console_auto_start": False,
"console_type": "telnet", # "console_type": "telnet",
"default_name_format": "PC{0}", # "default_name_format": "PC{0}",
"name": "ADMIN_VPCS_TEMPLATE", # "name": "ADMIN_VPCS_TEMPLATE",
"compute_id": "local", # "compute_id": "local",
"symbol": ":/symbols/vpcs_guest.svg", # "symbol": ":/symbols/vpcs_guest.svg",
"template_type": "vpcs"} # "template_type": "vpcs"}
#
response = await client.post(app.url_path_for("create_template"), json=new_template) # response = await client.post(app.url_path_for("create_template"), json=new_template)
assert response.status_code == status.HTTP_201_CREATED # assert response.status_code == status.HTTP_201_CREATED
#
async def test_user_only_access_own_templates( # async def test_user_only_access_own_templates(
self, app: FastAPI, # self, app: FastAPI,
authorized_client: AsyncClient, # authorized_client: AsyncClient,
test_user: User, # test_user: User,
db_session: AsyncSession # db_session: AsyncSession
) -> None: # ) -> None:
#
new_template = {"base_script_file": "vpcs_base_config.txt", # new_template = {"base_script_file": "vpcs_base_config.txt",
"category": "guest", # "category": "guest",
"console_auto_start": False, # "console_auto_start": False,
"console_type": "telnet", # "console_type": "telnet",
"default_name_format": "PC{0}", # "default_name_format": "PC{0}",
"name": "USER_VPCS_TEMPLATE", # "name": "USER_VPCS_TEMPLATE",
"compute_id": "local", # "compute_id": "local",
"symbol": ":/symbols/vpcs_guest.svg", # "symbol": ":/symbols/vpcs_guest.svg",
"template_type": "vpcs"} # "template_type": "vpcs"}
#
response = await authorized_client.post(app.url_path_for("create_template"), json=new_template) # response = await authorized_client.post(app.url_path_for("create_template"), json=new_template)
assert response.status_code == status.HTTP_201_CREATED # assert response.status_code == status.HTTP_201_CREATED
template_id = response.json()["template_id"] # template_id = response.json()["template_id"]
#
rbac_repo = RbacRepository(db_session) # rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id) # permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions_in_db) == 1 # assert len(permissions_in_db) == 1
assert permissions_in_db[0].path == f"/templates/{template_id}/*" # assert permissions_in_db[0].path == f"/templates/{template_id}/*"
#
response = await authorized_client.get(app.url_path_for("get_templates")) # response = await authorized_client.get(app.url_path_for("get_templates"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
templates = [template for template in response.json() if template["builtin"] is False] # templates = [template for template in response.json() if template["builtin"] is False]
assert len(templates) == 1 # assert len(templates) == 1
#
async def test_admin_access_all_templates(self, app: FastAPI, client: AsyncClient): # async def test_admin_access_all_templates(self, app: FastAPI, client: AsyncClient):
#
response = await client.get(app.url_path_for("get_templates")) # response = await client.get(app.url_path_for("get_templates"))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
templates = [template for template in response.json() if template["builtin"] is False] # templates = [template for template in response.json() if template["builtin"] is False]
assert len(templates) == 2 # assert len(templates) == 2