1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-24 17:28:08 +00:00

Basic functional RBAC support.

This commit is contained in:
grossmj 2021-05-27 17:28:44 +09:30
parent 6d4da98b8e
commit fbc47598d9
19 changed files with 527 additions and 92 deletions

View File

@ -55,7 +55,7 @@ router.include_router(
)
router.include_router(
roles.router,
permissions.router,
dependencies=[Depends(get_current_active_user)],
prefix="/permissions",
tags=["Permissions"]

View File

@ -15,13 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from fastapi import Depends, HTTPException, status
from fastapi import Request, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.services import auth_service
from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
@ -42,7 +42,11 @@ async def get_user_from_token(
return user
async def get_current_active_user(current_user: schemas.User = Depends(get_user_from_token)) -> schemas.User:
async def get_current_active_user(
request: Request,
current_user: schemas.User = Depends(get_user_from_token),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.User:
# Super admin is always authorized
if current_user.is_superadmin:
@ -54,4 +58,19 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_user_
detail="Not an active user",
headers={"WWW-Authenticate": "Bearer"},
)
# remove the prefix (e.g. "/v3") from URL path
if request.url.path.startswith("/v3"):
path = request.url.path[len("/v3"):]
else:
path = request.url.path
authorized = await rbac_repo.check_user_is_authorized(current_user.user_id, request.method, path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{current_user.user_id}' on {request.method} '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user

View File

@ -99,7 +99,7 @@ async def update_user_group(
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable:
if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
return await users_repo.update_user_group(user_group_id, user_group_update)
@ -121,7 +121,7 @@ async def delete_user_group(
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable:
if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
success = await users_repo.delete_user_group(user_group_id)

View File

@ -60,8 +60,9 @@ async def create_permission(
Create a new permission.
"""
# if await rbac_repo.get_role_by_path(role_create.name):
# raise ControllerBadRequestError(f"Role '{role_create.name}' already exists")
if await rbac_repo.check_permission_exists(permission_create):
raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path} "
f"{permission_create.action}' already exists")
return await rbac_repo.create_permission(permission_create)
@ -95,9 +96,6 @@ async def update_permission(
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
return await rbac_repo.update_permission(permission_id, permission_update)
@ -114,9 +112,6 @@ async def delete_permission(
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
success = await rbac_repo.delete_permission(permission_id)
if not success:
raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted")

View File

@ -47,6 +47,10 @@ from gns3server.controller.export_project import export_project as export_contro
from gns3server.utils.asyncio import aiozipstream
from gns3server.utils.path import is_safe_path
from gns3server.config import Config
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -82,13 +86,18 @@ def get_projects() -> List[schemas.Project]:
response_model_exclude_unset=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}},
)
async def create_project(project_data: schemas.ProjectCreate) -> schemas.Project:
async def create_project(
project_data: schemas.ProjectCreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project:
"""
Create a new project.
"""
controller = Controller.instance()
project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True))
await rbac_repo.add_permission_to_user(current_user.user_id, f"/projects/{project.id}/*")
return project.asdict()
@ -115,7 +124,10 @@ async def update_project(
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(project: Project = Depends(dep_project)) -> None:
async def delete_project(
project: Project = Depends(dep_project),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Delete a project.
"""
@ -123,6 +135,7 @@ async def delete_project(project: Project = Depends(dep_project)) -> None:
controller = Controller.instance()
await project.delete()
controller.remove_project(project)
await rbac_repo.delete_all_permissions_matching_path(f"/projects/{project.id}")
@router.get("/{project_id}/stats")

View File

@ -95,8 +95,8 @@ async def update_role(
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
if role.builtin:
raise ControllerForbiddenError(f"Role '{role_id}' cannot be updated")
return await rbac_repo.update_role(role_id, role_update)
@ -114,8 +114,8 @@ async def delete_role(
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
if role.builtin:
raise ControllerForbiddenError(f"Role '{role_id}' cannot be deleted")
success = await rbac_repo.delete_role(role_id)
if not success:

View File

@ -88,6 +88,24 @@ async def authenticate(
return token
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
async def get_users(
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
@ -178,15 +196,6 @@ async def delete_user(
raise ControllerNotFoundError(f"User '{user_id}' could not be deleted")
@router.get("/me/", response_model=schemas.User)
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get(
"/{user_id}/groups",
dependencies=[Depends(get_current_active_user)],

View File

@ -15,7 +15,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, ForeignKey, Boolean
from sqlalchemy import Table, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID, ListType
@ -39,7 +39,78 @@ class Permission(BaseTable):
__tablename__ = "permissions"
permission_id = Column(GUID, primary_key=True, default=generate_uuid)
description = Column(String)
methods = Column(ListType)
path = Column(String)
action = Column(String)
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE"))
roles = relationship("Role", secondary=permission_role_link, back_populates="permissions")
@event.listens_for(Permission.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_permissions = [
{
"description": "Allow access to all endpoints",
"methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"],
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow access to the logged in user",
"methods": ["GET"],
"path": "/users/me",
"action": "ALLOW"
},
{
"description": "Allow to create a project or list projects",
"methods": ["GET", "POST"],
"path": "/projects",
"action": "ALLOW"
},
{
"description": "Allow to access to all symbol endpoints",
"methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"],
"path": "/symbols",
"action": "ALLOW"
},
]
stmt = target.insert().values(default_permissions)
connection.execute(stmt)
connection.commit()
log.debug("The default permissions have been created in the database")
@event.listens_for(permission_role_link, 'after_create')
def add_permissions_to_role(target, connection, **kw):
from .roles import Role
roles_table = Role.__table__
stmt = roles_table.select().where(roles_table.c.name == "Administrator")
result = connection.execute(stmt)
role_id = result.first().role_id
permissions_table = Permission.__table__
stmt = permissions_table.select().where(permissions_table.c.path == "/")
result = connection.execute(stmt)
permission_id = result.first().permission_id
# add root path to the "Administrator" role
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
stmt = roles_table.select().where(roles_table.c.name == "User")
result = connection.execute(stmt)
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/projects", "/symbols", "/users/me"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
connection.commit()

View File

@ -40,7 +40,7 @@ class Role(BaseTable):
role_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String)
description = Column(String)
is_updatable = Column(Boolean, default=True)
builtin = Column(Boolean, default=False)
permissions = relationship("Permission", secondary=permission_role_link, back_populates="roles")
groups = relationship("UserGroup", secondary=role_group_link, back_populates="roles")
@ -49,11 +49,33 @@ class Role(BaseTable):
def create_default_roles(target, connection, **kw):
default_roles = [
{"name": "Administrator", "description": "Administrator role", "is_updatable": False},
{"name": "User", "description": "User role", "is_updatable": False},
{"name": "Administrator", "description": "Administrator role", "builtin": True},
{"name": "User", "description": "User role", "builtin": True},
]
stmt = target.insert().values(default_roles)
connection.execute(stmt)
connection.commit()
log.info("The default roles have been created in the database")
log.debug("The default roles have been created in the database")
@event.listens_for(role_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw):
from .users import UserGroup
user_groups_table = UserGroup.__table__
roles_table = Role.__table__
# Add roles to built-in user groups
groups_to_roles = {"Administrators": "Administrator", "Users": "User"}
for user_group, role in groups_to_roles.items():
stmt = user_groups_table.select().where(user_groups_table.c.name == user_group)
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
stmt = roles_table.select().where(roles_table.c.name == role)
result = connection.execute(stmt)
role_id = result.first().role_id
stmt = target.insert().values(role_id=role_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()

View File

@ -48,7 +48,6 @@ class User(BaseTable):
is_active = Column(Boolean, default=True)
is_superadmin = Column(Boolean, default=False)
groups = relationship("UserGroup", secondary=user_group_link, back_populates="users")
permission_id = Column(GUID, ForeignKey('permissions.permission_id', ondelete="CASCADE"))
permissions = relationship("Permission")
@ -67,7 +66,7 @@ def create_default_super_admin(target, connection, **kw):
)
connection.execute(stmt)
connection.commit()
log.info("The default super admin account has been created in the database")
log.debug("The default super admin account has been created in the database")
class UserGroup(BaseTable):
@ -76,7 +75,7 @@ class UserGroup(BaseTable):
user_group_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String, unique=True, index=True)
is_updatable = Column(Boolean, default=True)
builtin = Column(Boolean, default=False)
users = relationship("User", secondary=user_group_link, back_populates="groups")
roles = relationship("Role", secondary=role_group_link, back_populates="groups")
@ -85,30 +84,29 @@ class UserGroup(BaseTable):
def create_default_user_groups(target, connection, **kw):
default_groups = [
{"name": "Administrators", "is_updatable": False},
{"name": "Editors", "is_updatable": False},
{"name": "Users", "is_updatable": False}
{"name": "Administrators", "builtin": True},
{"name": "Users", "builtin": True}
]
stmt = target.insert().values(default_groups)
connection.execute(stmt)
connection.commit()
log.info("The default user groups have been created in the database")
log.debug("The default user groups have been created in the database")
@event.listens_for(user_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw):
user_groups_table = UserGroup.__table__
stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators")
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
users_table = User.__table__
stmt = users_table.select().where(users_table.c.is_superadmin.is_(True))
result = connection.execute(stmt)
user_id = result.first().user_id
stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()
# @event.listens_for(user_group_link, 'after_create')
# def add_admin_to_group(target, connection, **kw):
#
# user_groups_table = UserGroup.__table__
# stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators")
# result = connection.execute(stmt)
# user_group_id = result.first().user_group_id
#
# users_table = User.__table__
# stmt = users_table.select().where(users_table.c.is_superadmin.is_(True))
# result = connection.execute(stmt)
# user_id = result.first().user_id
#
# stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
# connection.execute(stmt)
# connection.commit()

View File

@ -24,6 +24,7 @@ from sqlalchemy.orm import selectinload
from .base import BaseRepository
import gns3server.db.models as models
from gns3server.schemas.controller.rbac import HTTPMethods, PermissionAction
from gns3server import schemas
import logging
@ -38,6 +39,9 @@ class RbacRepository(BaseRepository):
super().__init__(db_session)
async def get_role(self, role_id: UUID) -> Optional[models.Role]:
"""
Get a role by its ID.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
@ -46,6 +50,9 @@ class RbacRepository(BaseRepository):
return result.scalars().first()
async def get_role_by_name(self, name: str) -> Optional[models.Role]:
"""
Get a role by its name.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
@ -55,12 +62,18 @@ class RbacRepository(BaseRepository):
return result.scalars().first()
async def get_roles(self) -> List[models.Role]:
"""
Get all roles.
"""
query = select(models.Role).options(selectinload(models.Role.permissions))
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_role(self, role_create: schemas.RoleCreate) -> models.Role:
"""
Create a new role.
"""
db_role = models.Role(
name=role_create.name,
@ -76,6 +89,9 @@ class RbacRepository(BaseRepository):
role_id: UUID,
role_update: schemas.RoleUpdate
) -> Optional[models.Role]:
"""
Update a role.
"""
update_values = role_update.dict(exclude_unset=True)
query = update(models.Role).where(models.Role.role_id == role_id).values(update_values)
@ -85,6 +101,9 @@ class RbacRepository(BaseRepository):
return await self.get_role(role_id)
async def delete_role(self, role_id: UUID) -> bool:
"""
Delete a role.
"""
query = delete(models.Role).where(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
@ -96,6 +115,9 @@ class RbacRepository(BaseRepository):
role_id: UUID,
permission: models.Permission
) -> Union[None, models.Role]:
"""
Add a permission to a role.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
@ -115,6 +137,9 @@ class RbacRepository(BaseRepository):
role_id: UUID,
permission: models.Permission
) -> Union[None, models.Role]:
"""
Remove a permission from a role.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
@ -130,6 +155,9 @@ class RbacRepository(BaseRepository):
return role_db
async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]:
"""
Get all the role permissions.
"""
query = select(models.Permission).\
join(models.Permission.roles).\
@ -139,30 +167,48 @@ class RbacRepository(BaseRepository):
return result.scalars().all()
async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]:
"""
Get a permission by its ID.
"""
query = select(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_permission_by_path(self, path: str) -> Optional[models.Permission]:
"""
Get a permission by its path.
"""
query = select(models.Permission).where(models.Permission.path == path)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_permissions(self) -> List[models.Permission]:
"""
Get all permissions.
"""
query = select(models.Permission)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission:
async def check_permission_exists(self, permission_create: schemas.PermissionCreate) -> bool:
"""
Check if a permission exists.
"""
create_values = permission_create.dict(exclude_unset=True)
# action = create_values.pop("action", "deny")
# is_allowed = False
# if action == "allow":
# is_allowed = True
query = select(models.Permission).\
where(models.Permission.methods == permission_create.methods,
models.Permission.path == permission_create.path,
models.Permission.action == permission_create.action)
result = await self._db_session.execute(query)
return result.scalars().first() is not None
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission:
"""
Create a new permission.
"""
db_permission = models.Permission(
methods=permission_create.methods,
@ -170,7 +216,6 @@ class RbacRepository(BaseRepository):
action=permission_create.action,
)
self._db_session.add(db_permission)
await self._db_session.commit()
await self._db_session.refresh(db_permission)
return db_permission
@ -180,6 +225,9 @@ class RbacRepository(BaseRepository):
permission_id: UUID,
permission_update: schemas.PermissionUpdate
) -> Optional[models.Permission]:
"""
Update a permission.
"""
update_values = permission_update.dict(exclude_unset=True)
query = update(models.Permission).where(models.Permission.permission_id == permission_id).values(update_values)
@ -189,8 +237,110 @@ class RbacRepository(BaseRepository):
return await self.get_permission(permission_id)
async def delete_permission(self, permission_id: UUID) -> bool:
"""
Delete a permission.
"""
query = delete(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0
def _match_permission(
self,
permissions: List[models.Permission],
method: str,
path: str
) -> Union[None, models.Permission]:
"""
Match the methods and path with a permission.
"""
for permission in permissions:
log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
if method not in permission.methods:
continue
if permission.path.endswith("*") and path.startswith(permission.path[:-1]):
return permission
elif permission.path == path:
return permission
async def check_user_is_authorized(self, user_id: UUID, method: str, path: str) -> bool:
"""
Check if an user is authorized to access a resource.
"""
query = select(models.Permission).\
join(models.Permission.roles). \
join(models.Role.groups). \
join(models.UserGroup.users). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
permissions = result.scalars().all()
log.debug(f"RBAC: checking authorization for '{user_id}' on {method} '{path}'")
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched role permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
log.debug(f"RBAC: could not find a role permission, checking user permissions...")
query = select(models.Permission).\
join(models.User.permissions). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
permissions = result.scalars().all()
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched user permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
return False
async def add_permission_to_user(self, user_id: UUID, path: str) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
# Create a new permission with full rights
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get, HTTPMethods.head, HTTPMethods.post, HTTPMethods.put, HTTPMethods.delete],
path=path,
action=PermissionAction.allow
)
permission_db = await self.create_permission(new_permission)
# Add the permission to the user
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission_db)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def delete_all_permissions_matching_path(self, path: str) -> None:
"""
Delete all permissions matching with path.
"""
query = delete(models.Permission).\
where(models.Permission.path.startswith(path)).\
execution_options(synchronize_session=False)
result = await self._db_session.execute(query)
log.debug(f"{result.rowcount} permission(s) have been deleted")

View File

@ -33,40 +33,59 @@ log = logging.getLogger(__name__)
class UsersRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session)
self._auth_service = auth_service
async def get_user(self, user_id: UUID) -> Optional[models.User]:
"""
Get an user by its ID.
"""
query = select(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_by_username(self, username: str) -> Optional[models.User]:
"""
Get an user by its name.
"""
query = select(models.User).where(models.User.username == username)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_by_email(self, email: str) -> Optional[models.User]:
"""
Get an user by its email.
"""
query = select(models.User).where(models.User.email == email)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_users(self) -> List[models.User]:
"""
Get all users.
"""
query = select(models.User)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_user(self, user: schemas.UserCreate) -> models.User:
"""
Create a new user.
"""
hashed_password = self._auth_service.hash_password(user.password.get_secret_value())
db_user = models.User(
username=user.username, email=user.email, full_name=user.full_name, hashed_password=hashed_password
username=user.username,
email=user.email,
full_name=user.full_name,
hashed_password=hashed_password
)
self._db_session.add(db_user)
await self._db_session.commit()
@ -74,6 +93,9 @@ class UsersRepository(BaseRepository):
return db_user
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
"""
Update an user.
"""
update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None)
@ -87,6 +109,9 @@ class UsersRepository(BaseRepository):
return await self.get_user(user_id)
async def delete_user(self, user_id: UUID) -> bool:
"""
Delete an user.
"""
query = delete(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
@ -94,6 +119,9 @@ class UsersRepository(BaseRepository):
return result.rowcount > 0
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
"""
Authenticate an user.
"""
user = await self.get_user_by_username(username)
if not user:
@ -110,6 +138,9 @@ class UsersRepository(BaseRepository):
return user
async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]:
"""
Get all user memberships (user groups).
"""
query = select(models.UserGroup).\
join(models.UserGroup.users).\
@ -119,24 +150,36 @@ class UsersRepository(BaseRepository):
return result.scalars().all()
async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]:
"""
Get an user group by its ID.
"""
query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]:
"""
Get an user group by its name.
"""
query = select(models.UserGroup).where(models.UserGroup.name == name)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_groups(self) -> List[models.UserGroup]:
"""
Get all user groups.
"""
query = select(models.UserGroup)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup:
"""
Create a new user group.
"""
db_user_group = models.UserGroup(name=user_group.name)
self._db_session.add(db_user_group)
@ -149,6 +192,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID,
user_group_update: schemas.UserGroupUpdate
) -> Optional[models.UserGroup]:
"""
Update an user group.
"""
update_values = user_group_update.dict(exclude_unset=True)
query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values)
@ -158,6 +204,9 @@ class UsersRepository(BaseRepository):
return await self.get_user_group(user_group_id)
async def delete_user_group(self, user_group_id: UUID) -> bool:
"""
Delete an user group.
"""
query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
@ -169,6 +218,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID,
user: models.User
) -> Union[None, models.UserGroup]:
"""
Add a member to an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\
@ -188,6 +240,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID,
user: models.User
) -> Union[None, models.UserGroup]:
"""
Remove a member from an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\
@ -203,6 +258,9 @@ class UsersRepository(BaseRepository):
return user_group_db
async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]:
"""
Get all members from an user group.
"""
query = select(models.User).\
join(models.User.groups).\
@ -216,6 +274,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Add a role to an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
@ -235,6 +296,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Remove a role from an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
@ -250,6 +314,9 @@ class UsersRepository(BaseRepository):
return user_group_db
async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]:
"""
Get all roles from an user group.
"""
query = select(models.Role). \
options(selectinload(models.Role.permissions)). \

View File

@ -114,6 +114,7 @@ class RoleUpdate(RoleBase):
class Role(DateTimeModelMixin, RoleBase):
role_id: UUID
builtin: bool
permissions: List[Permission]
class Config:

View File

@ -85,7 +85,7 @@ class UserGroupUpdate(UserGroupBase):
class UserGroup(DateTimeModelMixin, UserGroupBase):
user_group_id: UUID
is_updatable: bool
builtin: bool
class Config:
orm_mode = True

View File

@ -50,7 +50,7 @@ class TestGroupRoutes:
response = await client.get(app.url_path_for("get_user_groups"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 4 # 3 default groups + group1
assert len(response.json()) == 3 # 2 default groups + group1
async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
@ -206,8 +206,10 @@ class TestGroupRolesRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 1
assert roles[0].name == test_role.name
assert len(roles) == 2 # 1 default role + 1 custom role
for role in roles:
if not role.builtin:
assert role.name == test_role.name
async def test_get_user_group_roles(
self,
@ -224,7 +226,7 @@ class TestGroupRolesRoutes:
user_group_id=group_in_db.user_group_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
assert len(response.json()) == 2 # 1 default role + 1 custom role
async def test_remove_role_from_group(
self,
@ -246,4 +248,5 @@ class TestGroupRolesRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 0
assert len(roles) == 1 # 1 default role
assert roles[0].name != test_role.name

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
pytestmark = pytest.mark.asyncio
class TestPermissionRoutes:
async def test_create_permission(self, app: FastAPI, client: AsyncClient) -> None:
new_permission = {
"methods": ["GET"],
"path": "/templates",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
async def test_get_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
response = await client.get(app.url_path_for("get_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["permission_id"] == str(permission_in_db.permission_id)
async def test_list_permissions(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_permissions"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 5 # 4 default permissions + 1 custom permission
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
update_permission = {
"methods": ["GET"],
"path": "/appliances",
"action": "ALLOW"
}
response = await client.put(
app.url_path_for("update_permission", permission_id=permission_in_db.permission_id),
json=update_permission
)
assert response.status_code == status.HTTP_200_OK
updated_permission_in_db = await rbac_repo.get_permission(permission_in_db.permission_id)
assert updated_permission_in_db.path == "/appliances"
async def test_delete_permission(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/appliances")
response = await client.delete(app.url_path_for("delete_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_204_NO_CONTENT

View File

@ -110,11 +110,11 @@ async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get, HTTPMethods.post],
path="/projects",
path="/templates",
action=PermissionAction.allow
)
rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/projects")
existing_permission = await rbac_repo.get_permission_by_path("/templates")
if existing_permission:
return existing_permission
return await rbac_repo.create_permission(new_permission)
@ -142,8 +142,7 @@ class TestRolesPermissionsRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 1
assert permissions[0].path == test_permission.path
assert len(permissions) == 4 # 3 default + 1 custom permissions
async def test_get_role_permissions(
self,
@ -161,7 +160,7 @@ class TestRolesPermissionsRoutes:
role_id=role_in_db.role_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
assert len(response.json()) == 4 # 3 default + 1 custom permissions
async def test_remove_role_from_group(
self,
@ -183,4 +182,4 @@ class TestRolesPermissionsRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 0
assert len(permissions) == 3 # 3 default permissions

View File

@ -266,7 +266,7 @@ class TestUserMe:
test_user: User,
) -> None:
response = await authorized_client.get(app.url_path_for("get_current_active_user"))
response = await authorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_200_OK
user = User(**response.json())
assert user.username == test_user.username
@ -279,7 +279,7 @@ class TestUserMe:
test_user: User,
) -> None:
response = await unauthorized_client.get(app.url_path_for("get_current_active_user"))
response = await unauthorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -329,15 +329,15 @@ class TestSuperAdmin:
response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert response.status_code == status.HTTP_200_OK
async def test_super_admin_belongs_to_admin_group(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
admin_in_db = await user_repo.get_user_by_username("admin")
response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
# async def test_super_admin_belongs_to_admin_group(
# self,
# app: FastAPI,
# client: AsyncClient,
# db_session: AsyncSession
# ) -> None:
#
# user_repo = UsersRepository(db_session)
# admin_in_db = await user_repo.get_user_by_username("admin")
# response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
# assert response.status_code == status.HTTP_200_OK
# assert len(response.json()) == 1

View File

@ -124,7 +124,12 @@ async def test_user(db_session: AsyncSession) -> User:
existing_user = await user_repo.get_user_by_username(new_user.username)
if existing_user:
return existing_user
return await user_repo.create_user(new_user)
user = await user_repo.create_user(new_user)
# add new user to "Users group
group = await user_repo.get_user_group_by_name("Users")
await user_repo.add_member_to_user_group(group.user_group_id, user)
return user
@pytest.fixture