1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-28 19:28:07 +00:00

Basic functional RBAC support.

This commit is contained in:
grossmj 2021-05-27 17:28:44 +09:30
parent 6d4da98b8e
commit fbc47598d9
19 changed files with 527 additions and 92 deletions

View File

@ -55,7 +55,7 @@ router.include_router(
) )
router.include_router( router.include_router(
roles.router, permissions.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/permissions", prefix="/permissions",
tags=["Permissions"] tags=["Permissions"]

View File

@ -15,13 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from fastapi import Depends, HTTPException, status from fastapi import Request, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from gns3server import schemas from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from .database import get_repository from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
@ -42,7 +42,11 @@ async def get_user_from_token(
return user return user
async def get_current_active_user(current_user: schemas.User = Depends(get_user_from_token)) -> schemas.User: async def get_current_active_user(
request: Request,
current_user: schemas.User = Depends(get_user_from_token),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.User:
# Super admin is always authorized # Super admin is always authorized
if current_user.is_superadmin: if current_user.is_superadmin:
@ -54,4 +58,19 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_user_
detail="Not an active user", detail="Not an active user",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# remove the prefix (e.g. "/v3") from URL path
if request.url.path.startswith("/v3"):
path = request.url.path[len("/v3"):]
else:
path = request.url.path
authorized = await rbac_repo.check_user_is_authorized(current_user.user_id, request.method, path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{current_user.user_id}' on {request.method} '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user return current_user

View File

@ -99,7 +99,7 @@ async def update_user_group(
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable: if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated") raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
return await users_repo.update_user_group(user_group_id, user_group_update) return await users_repo.update_user_group(user_group_id, user_group_update)
@ -121,7 +121,7 @@ async def delete_user_group(
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable: if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted") raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
success = await users_repo.delete_user_group(user_group_id) success = await users_repo.delete_user_group(user_group_id)

View File

@ -60,8 +60,9 @@ async def create_permission(
Create a new permission. Create a new permission.
""" """
# if await rbac_repo.get_role_by_path(role_create.name): if await rbac_repo.check_permission_exists(permission_create):
# raise ControllerBadRequestError(f"Role '{role_create.name}' already exists") raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path} "
f"{permission_create.action}' already exists")
return await rbac_repo.create_permission(permission_create) return await rbac_repo.create_permission(permission_create)
@ -95,9 +96,6 @@ async def update_permission(
if not permission: if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found") raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
return await rbac_repo.update_permission(permission_id, permission_update) return await rbac_repo.update_permission(permission_id, permission_update)
@ -114,9 +112,6 @@ async def delete_permission(
if not permission: if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found") raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
#if not user_group.is_updatable:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
success = await rbac_repo.delete_permission(permission_id) success = await rbac_repo.delete_permission(permission_id)
if not success: if not success:
raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted") raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted")

View File

@ -47,6 +47,10 @@ from gns3server.controller.export_project import export_project as export_contro
from gns3server.utils.asyncio import aiozipstream from gns3server.utils.asyncio import aiozipstream
from gns3server.utils.path import is_safe_path from gns3server.utils.path import is_safe_path
from gns3server.config import Config from gns3server.config import Config
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -82,13 +86,18 @@ def get_projects() -> List[schemas.Project]:
response_model_exclude_unset=True, response_model_exclude_unset=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}}, responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}},
) )
async def create_project(project_data: schemas.ProjectCreate) -> schemas.Project: async def create_project(
project_data: schemas.ProjectCreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project:
""" """
Create a new project. Create a new project.
""" """
controller = Controller.instance() controller = Controller.instance()
project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True)) project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True))
await rbac_repo.add_permission_to_user(current_user.user_id, f"/projects/{project.id}/*")
return project.asdict() return project.asdict()
@ -115,7 +124,10 @@ async def update_project(
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(project: Project = Depends(dep_project)) -> None: async def delete_project(
project: Project = Depends(dep_project),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
""" """
Delete a project. Delete a project.
""" """
@ -123,6 +135,7 @@ async def delete_project(project: Project = Depends(dep_project)) -> None:
controller = Controller.instance() controller = Controller.instance()
await project.delete() await project.delete()
controller.remove_project(project) controller.remove_project(project)
await rbac_repo.delete_all_permissions_matching_path(f"/projects/{project.id}")
@router.get("/{project_id}/stats") @router.get("/{project_id}/stats")

View File

@ -95,8 +95,8 @@ async def update_role(
if not role: if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found") raise ControllerNotFoundError(f"Role '{role_id}' not found")
#if not user_group.is_updatable: if role.builtin:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated") raise ControllerForbiddenError(f"Role '{role_id}' cannot be updated")
return await rbac_repo.update_role(role_id, role_update) return await rbac_repo.update_role(role_id, role_update)
@ -114,8 +114,8 @@ async def delete_role(
if not role: if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found") raise ControllerNotFoundError(f"Role '{role_id}' not found")
#if not user_group.is_updatable: if role.builtin:
# raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted") raise ControllerForbiddenError(f"Role '{role_id}' cannot be deleted")
success = await rbac_repo.delete_role(role_id) success = await rbac_repo.delete_role(role_id)
if not success: if not success:

View File

@ -88,6 +88,24 @@ async def authenticate(
return token return token
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)]) @router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
async def get_users( async def get_users(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
@ -178,15 +196,6 @@ async def delete_user(
raise ControllerNotFoundError(f"User '{user_id}' could not be deleted") raise ControllerNotFoundError(f"User '{user_id}' could not be deleted")
@router.get("/me/", response_model=schemas.User)
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get( @router.get(
"/{user_id}/groups", "/{user_id}/groups",
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],

View File

@ -15,7 +15,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, ForeignKey, Boolean from sqlalchemy import Table, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID, ListType from .base import Base, BaseTable, generate_uuid, GUID, ListType
@ -39,7 +39,78 @@ class Permission(BaseTable):
__tablename__ = "permissions" __tablename__ = "permissions"
permission_id = Column(GUID, primary_key=True, default=generate_uuid) permission_id = Column(GUID, primary_key=True, default=generate_uuid)
description = Column(String)
methods = Column(ListType) methods = Column(ListType)
path = Column(String) path = Column(String)
action = Column(String) action = Column(String)
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE"))
roles = relationship("Role", secondary=permission_role_link, back_populates="permissions") roles = relationship("Role", secondary=permission_role_link, back_populates="permissions")
@event.listens_for(Permission.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_permissions = [
{
"description": "Allow access to all endpoints",
"methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"],
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow access to the logged in user",
"methods": ["GET"],
"path": "/users/me",
"action": "ALLOW"
},
{
"description": "Allow to create a project or list projects",
"methods": ["GET", "POST"],
"path": "/projects",
"action": "ALLOW"
},
{
"description": "Allow to access to all symbol endpoints",
"methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"],
"path": "/symbols",
"action": "ALLOW"
},
]
stmt = target.insert().values(default_permissions)
connection.execute(stmt)
connection.commit()
log.debug("The default permissions have been created in the database")
@event.listens_for(permission_role_link, 'after_create')
def add_permissions_to_role(target, connection, **kw):
from .roles import Role
roles_table = Role.__table__
stmt = roles_table.select().where(roles_table.c.name == "Administrator")
result = connection.execute(stmt)
role_id = result.first().role_id
permissions_table = Permission.__table__
stmt = permissions_table.select().where(permissions_table.c.path == "/")
result = connection.execute(stmt)
permission_id = result.first().permission_id
# add root path to the "Administrator" role
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
stmt = roles_table.select().where(roles_table.c.name == "User")
result = connection.execute(stmt)
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/projects", "/symbols", "/users/me"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
connection.commit()

View File

@ -40,7 +40,7 @@ class Role(BaseTable):
role_id = Column(GUID, primary_key=True, default=generate_uuid) role_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String) name = Column(String)
description = Column(String) description = Column(String)
is_updatable = Column(Boolean, default=True) builtin = Column(Boolean, default=False)
permissions = relationship("Permission", secondary=permission_role_link, back_populates="roles") permissions = relationship("Permission", secondary=permission_role_link, back_populates="roles")
groups = relationship("UserGroup", secondary=role_group_link, back_populates="roles") groups = relationship("UserGroup", secondary=role_group_link, back_populates="roles")
@ -49,11 +49,33 @@ class Role(BaseTable):
def create_default_roles(target, connection, **kw): def create_default_roles(target, connection, **kw):
default_roles = [ default_roles = [
{"name": "Administrator", "description": "Administrator role", "is_updatable": False}, {"name": "Administrator", "description": "Administrator role", "builtin": True},
{"name": "User", "description": "User role", "is_updatable": False}, {"name": "User", "description": "User role", "builtin": True},
] ]
stmt = target.insert().values(default_roles) stmt = target.insert().values(default_roles)
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default roles have been created in the database") log.debug("The default roles have been created in the database")
@event.listens_for(role_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw):
from .users import UserGroup
user_groups_table = UserGroup.__table__
roles_table = Role.__table__
# Add roles to built-in user groups
groups_to_roles = {"Administrators": "Administrator", "Users": "User"}
for user_group, role in groups_to_roles.items():
stmt = user_groups_table.select().where(user_groups_table.c.name == user_group)
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
stmt = roles_table.select().where(roles_table.c.name == role)
result = connection.execute(stmt)
role_id = result.first().role_id
stmt = target.insert().values(role_id=role_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()

View File

@ -48,7 +48,6 @@ class User(BaseTable):
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
is_superadmin = Column(Boolean, default=False) is_superadmin = Column(Boolean, default=False)
groups = relationship("UserGroup", secondary=user_group_link, back_populates="users") groups = relationship("UserGroup", secondary=user_group_link, back_populates="users")
permission_id = Column(GUID, ForeignKey('permissions.permission_id', ondelete="CASCADE"))
permissions = relationship("Permission") permissions = relationship("Permission")
@ -67,7 +66,7 @@ def create_default_super_admin(target, connection, **kw):
) )
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default super admin account has been created in the database") log.debug("The default super admin account has been created in the database")
class UserGroup(BaseTable): class UserGroup(BaseTable):
@ -76,7 +75,7 @@ class UserGroup(BaseTable):
user_group_id = Column(GUID, primary_key=True, default=generate_uuid) user_group_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String, unique=True, index=True) name = Column(String, unique=True, index=True)
is_updatable = Column(Boolean, default=True) builtin = Column(Boolean, default=False)
users = relationship("User", secondary=user_group_link, back_populates="groups") users = relationship("User", secondary=user_group_link, back_populates="groups")
roles = relationship("Role", secondary=role_group_link, back_populates="groups") roles = relationship("Role", secondary=role_group_link, back_populates="groups")
@ -85,30 +84,29 @@ class UserGroup(BaseTable):
def create_default_user_groups(target, connection, **kw): def create_default_user_groups(target, connection, **kw):
default_groups = [ default_groups = [
{"name": "Administrators", "is_updatable": False}, {"name": "Administrators", "builtin": True},
{"name": "Editors", "is_updatable": False}, {"name": "Users", "builtin": True}
{"name": "Users", "is_updatable": False}
] ]
stmt = target.insert().values(default_groups) stmt = target.insert().values(default_groups)
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default user groups have been created in the database") log.debug("The default user groups have been created in the database")
@event.listens_for(user_group_link, 'after_create') # @event.listens_for(user_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw): # def add_admin_to_group(target, connection, **kw):
#
user_groups_table = UserGroup.__table__ # user_groups_table = UserGroup.__table__
stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators") # stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators")
result = connection.execute(stmt) # result = connection.execute(stmt)
user_group_id = result.first().user_group_id # user_group_id = result.first().user_group_id
#
users_table = User.__table__ # users_table = User.__table__
stmt = users_table.select().where(users_table.c.is_superadmin.is_(True)) # stmt = users_table.select().where(users_table.c.is_superadmin.is_(True))
result = connection.execute(stmt) # result = connection.execute(stmt)
user_id = result.first().user_id # user_id = result.first().user_id
#
stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id) # stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
connection.execute(stmt) # connection.execute(stmt)
connection.commit() # connection.commit()

View File

@ -24,6 +24,7 @@ from sqlalchemy.orm import selectinload
from .base import BaseRepository from .base import BaseRepository
import gns3server.db.models as models import gns3server.db.models as models
from gns3server.schemas.controller.rbac import HTTPMethods, PermissionAction
from gns3server import schemas from gns3server import schemas
import logging import logging
@ -38,6 +39,9 @@ class RbacRepository(BaseRepository):
super().__init__(db_session) super().__init__(db_session)
async def get_role(self, role_id: UUID) -> Optional[models.Role]: async def get_role(self, role_id: UUID) -> Optional[models.Role]:
"""
Get a role by its ID.
"""
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.permissions)).\
@ -46,6 +50,9 @@ class RbacRepository(BaseRepository):
return result.scalars().first() return result.scalars().first()
async def get_role_by_name(self, name: str) -> Optional[models.Role]: async def get_role_by_name(self, name: str) -> Optional[models.Role]:
"""
Get a role by its name.
"""
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.permissions)).\
@ -55,12 +62,18 @@ class RbacRepository(BaseRepository):
return result.scalars().first() return result.scalars().first()
async def get_roles(self) -> List[models.Role]: async def get_roles(self) -> List[models.Role]:
"""
Get all roles.
"""
query = select(models.Role).options(selectinload(models.Role.permissions)) query = select(models.Role).options(selectinload(models.Role.permissions))
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_role(self, role_create: schemas.RoleCreate) -> models.Role: async def create_role(self, role_create: schemas.RoleCreate) -> models.Role:
"""
Create a new role.
"""
db_role = models.Role( db_role = models.Role(
name=role_create.name, name=role_create.name,
@ -76,6 +89,9 @@ class RbacRepository(BaseRepository):
role_id: UUID, role_id: UUID,
role_update: schemas.RoleUpdate role_update: schemas.RoleUpdate
) -> Optional[models.Role]: ) -> Optional[models.Role]:
"""
Update a role.
"""
update_values = role_update.dict(exclude_unset=True) update_values = role_update.dict(exclude_unset=True)
query = update(models.Role).where(models.Role.role_id == role_id).values(update_values) query = update(models.Role).where(models.Role.role_id == role_id).values(update_values)
@ -85,6 +101,9 @@ class RbacRepository(BaseRepository):
return await self.get_role(role_id) return await self.get_role(role_id)
async def delete_role(self, role_id: UUID) -> bool: async def delete_role(self, role_id: UUID) -> bool:
"""
Delete a role.
"""
query = delete(models.Role).where(models.Role.role_id == role_id) query = delete(models.Role).where(models.Role.role_id == role_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
@ -96,6 +115,9 @@ class RbacRepository(BaseRepository):
role_id: UUID, role_id: UUID,
permission: models.Permission permission: models.Permission
) -> Union[None, models.Role]: ) -> Union[None, models.Role]:
"""
Add a permission to a role.
"""
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.permissions)).\
@ -115,6 +137,9 @@ class RbacRepository(BaseRepository):
role_id: UUID, role_id: UUID,
permission: models.Permission permission: models.Permission
) -> Union[None, models.Role]: ) -> Union[None, models.Role]:
"""
Remove a permission from a role.
"""
query = select(models.Role).\ query = select(models.Role).\
options(selectinload(models.Role.permissions)).\ options(selectinload(models.Role.permissions)).\
@ -130,6 +155,9 @@ class RbacRepository(BaseRepository):
return role_db return role_db
async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]: async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]:
"""
Get all the role permissions.
"""
query = select(models.Permission).\ query = select(models.Permission).\
join(models.Permission.roles).\ join(models.Permission.roles).\
@ -139,30 +167,48 @@ class RbacRepository(BaseRepository):
return result.scalars().all() return result.scalars().all()
async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]: async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]:
"""
Get a permission by its ID.
"""
query = select(models.Permission).where(models.Permission.permission_id == permission_id) query = select(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_permission_by_path(self, path: str) -> Optional[models.Permission]: async def get_permission_by_path(self, path: str) -> Optional[models.Permission]:
"""
Get a permission by its path.
"""
query = select(models.Permission).where(models.Permission.path == path) query = select(models.Permission).where(models.Permission.path == path)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_permissions(self) -> List[models.Permission]: async def get_permissions(self) -> List[models.Permission]:
"""
Get all permissions.
"""
query = select(models.Permission) query = select(models.Permission)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission: async def check_permission_exists(self, permission_create: schemas.PermissionCreate) -> bool:
"""
Check if a permission exists.
"""
create_values = permission_create.dict(exclude_unset=True) query = select(models.Permission).\
# action = create_values.pop("action", "deny") where(models.Permission.methods == permission_create.methods,
# is_allowed = False models.Permission.path == permission_create.path,
# if action == "allow": models.Permission.action == permission_create.action)
# is_allowed = True result = await self._db_session.execute(query)
return result.scalars().first() is not None
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission:
"""
Create a new permission.
"""
db_permission = models.Permission( db_permission = models.Permission(
methods=permission_create.methods, methods=permission_create.methods,
@ -170,7 +216,6 @@ class RbacRepository(BaseRepository):
action=permission_create.action, action=permission_create.action,
) )
self._db_session.add(db_permission) self._db_session.add(db_permission)
await self._db_session.commit() await self._db_session.commit()
await self._db_session.refresh(db_permission) await self._db_session.refresh(db_permission)
return db_permission return db_permission
@ -180,6 +225,9 @@ class RbacRepository(BaseRepository):
permission_id: UUID, permission_id: UUID,
permission_update: schemas.PermissionUpdate permission_update: schemas.PermissionUpdate
) -> Optional[models.Permission]: ) -> Optional[models.Permission]:
"""
Update a permission.
"""
update_values = permission_update.dict(exclude_unset=True) update_values = permission_update.dict(exclude_unset=True)
query = update(models.Permission).where(models.Permission.permission_id == permission_id).values(update_values) query = update(models.Permission).where(models.Permission.permission_id == permission_id).values(update_values)
@ -189,8 +237,110 @@ class RbacRepository(BaseRepository):
return await self.get_permission(permission_id) return await self.get_permission(permission_id)
async def delete_permission(self, permission_id: UUID) -> bool: async def delete_permission(self, permission_id: UUID) -> bool:
"""
Delete a permission.
"""
query = delete(models.Permission).where(models.Permission.permission_id == permission_id) query = delete(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
await self._db_session.commit() await self._db_session.commit()
return result.rowcount > 0 return result.rowcount > 0
def _match_permission(
self,
permissions: List[models.Permission],
method: str,
path: str
) -> Union[None, models.Permission]:
"""
Match the methods and path with a permission.
"""
for permission in permissions:
log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
if method not in permission.methods:
continue
if permission.path.endswith("*") and path.startswith(permission.path[:-1]):
return permission
elif permission.path == path:
return permission
async def check_user_is_authorized(self, user_id: UUID, method: str, path: str) -> bool:
"""
Check if an user is authorized to access a resource.
"""
query = select(models.Permission).\
join(models.Permission.roles). \
join(models.Role.groups). \
join(models.UserGroup.users). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
permissions = result.scalars().all()
log.debug(f"RBAC: checking authorization for '{user_id}' on {method} '{path}'")
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched role permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
log.debug(f"RBAC: could not find a role permission, checking user permissions...")
query = select(models.Permission).\
join(models.User.permissions). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
permissions = result.scalars().all()
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched user permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
return False
async def add_permission_to_user(self, user_id: UUID, path: str) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
# Create a new permission with full rights
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get, HTTPMethods.head, HTTPMethods.post, HTTPMethods.put, HTTPMethods.delete],
path=path,
action=PermissionAction.allow
)
permission_db = await self.create_permission(new_permission)
# Add the permission to the user
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission_db)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def delete_all_permissions_matching_path(self, path: str) -> None:
"""
Delete all permissions matching with path.
"""
query = delete(models.Permission).\
where(models.Permission.path.startswith(path)).\
execution_options(synchronize_session=False)
result = await self._db_session.execute(query)
log.debug(f"{result.rowcount} permission(s) have been deleted")

View File

@ -33,40 +33,59 @@ log = logging.getLogger(__name__)
class UsersRepository(BaseRepository): class UsersRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None: def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session) super().__init__(db_session)
self._auth_service = auth_service self._auth_service = auth_service
async def get_user(self, user_id: UUID) -> Optional[models.User]: async def get_user(self, user_id: UUID) -> Optional[models.User]:
"""
Get an user by its ID.
"""
query = select(models.User).where(models.User.user_id == user_id) query = select(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_by_username(self, username: str) -> Optional[models.User]: async def get_user_by_username(self, username: str) -> Optional[models.User]:
"""
Get an user by its name.
"""
query = select(models.User).where(models.User.username == username) query = select(models.User).where(models.User.username == username)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_by_email(self, email: str) -> Optional[models.User]: async def get_user_by_email(self, email: str) -> Optional[models.User]:
"""
Get an user by its email.
"""
query = select(models.User).where(models.User.email == email) query = select(models.User).where(models.User.email == email)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_users(self) -> List[models.User]: async def get_users(self) -> List[models.User]:
"""
Get all users.
"""
query = select(models.User) query = select(models.User)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_user(self, user: schemas.UserCreate) -> models.User: async def create_user(self, user: schemas.UserCreate) -> models.User:
"""
Create a new user.
"""
hashed_password = self._auth_service.hash_password(user.password.get_secret_value()) hashed_password = self._auth_service.hash_password(user.password.get_secret_value())
db_user = models.User( db_user = models.User(
username=user.username, email=user.email, full_name=user.full_name, hashed_password=hashed_password username=user.username,
email=user.email,
full_name=user.full_name,
hashed_password=hashed_password
) )
self._db_session.add(db_user) self._db_session.add(db_user)
await self._db_session.commit() await self._db_session.commit()
@ -74,6 +93,9 @@ class UsersRepository(BaseRepository):
return db_user return db_user
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]: async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
"""
Update an user.
"""
update_values = user_update.dict(exclude_unset=True) update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None) password = update_values.pop("password", None)
@ -87,6 +109,9 @@ class UsersRepository(BaseRepository):
return await self.get_user(user_id) return await self.get_user(user_id)
async def delete_user(self, user_id: UUID) -> bool: async def delete_user(self, user_id: UUID) -> bool:
"""
Delete an user.
"""
query = delete(models.User).where(models.User.user_id == user_id) query = delete(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
@ -94,6 +119,9 @@ class UsersRepository(BaseRepository):
return result.rowcount > 0 return result.rowcount > 0
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]: async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
"""
Authenticate an user.
"""
user = await self.get_user_by_username(username) user = await self.get_user_by_username(username)
if not user: if not user:
@ -110,6 +138,9 @@ class UsersRepository(BaseRepository):
return user return user
async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]: async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]:
"""
Get all user memberships (user groups).
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
join(models.UserGroup.users).\ join(models.UserGroup.users).\
@ -119,24 +150,36 @@ class UsersRepository(BaseRepository):
return result.scalars().all() return result.scalars().all()
async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]: async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]:
"""
Get an user group by its ID.
"""
query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]: async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]:
"""
Get an user group by its name.
"""
query = select(models.UserGroup).where(models.UserGroup.name == name) query = select(models.UserGroup).where(models.UserGroup.name == name)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_groups(self) -> List[models.UserGroup]: async def get_user_groups(self) -> List[models.UserGroup]:
"""
Get all user groups.
"""
query = select(models.UserGroup) query = select(models.UserGroup)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup: async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup:
"""
Create a new user group.
"""
db_user_group = models.UserGroup(name=user_group.name) db_user_group = models.UserGroup(name=user_group.name)
self._db_session.add(db_user_group) self._db_session.add(db_user_group)
@ -149,6 +192,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user_group_update: schemas.UserGroupUpdate user_group_update: schemas.UserGroupUpdate
) -> Optional[models.UserGroup]: ) -> Optional[models.UserGroup]:
"""
Update an user group.
"""
update_values = user_group_update.dict(exclude_unset=True) update_values = user_group_update.dict(exclude_unset=True)
query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values) query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values)
@ -158,6 +204,9 @@ class UsersRepository(BaseRepository):
return await self.get_user_group(user_group_id) return await self.get_user_group(user_group_id)
async def delete_user_group(self, user_group_id: UUID) -> bool: async def delete_user_group(self, user_group_id: UUID) -> bool:
"""
Delete an user group.
"""
query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
@ -169,6 +218,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user: models.User user: models.User
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Add a member to an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\ options(selectinload(models.UserGroup.users)).\
@ -188,6 +240,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user: models.User user: models.User
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Remove a member from an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\ options(selectinload(models.UserGroup.users)).\
@ -203,6 +258,9 @@ class UsersRepository(BaseRepository):
return user_group_db return user_group_db
async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]: async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]:
"""
Get all members from an user group.
"""
query = select(models.User).\ query = select(models.User).\
join(models.User.groups).\ join(models.User.groups).\
@ -216,6 +274,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
role: models.Role role: models.Role
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Add a role to an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\ options(selectinload(models.UserGroup.roles)).\
@ -235,6 +296,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
role: models.Role role: models.Role
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Remove a role from an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\ options(selectinload(models.UserGroup.roles)).\
@ -250,6 +314,9 @@ class UsersRepository(BaseRepository):
return user_group_db return user_group_db
async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]: async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]:
"""
Get all roles from an user group.
"""
query = select(models.Role). \ query = select(models.Role). \
options(selectinload(models.Role.permissions)). \ options(selectinload(models.Role.permissions)). \

View File

@ -114,6 +114,7 @@ class RoleUpdate(RoleBase):
class Role(DateTimeModelMixin, RoleBase): class Role(DateTimeModelMixin, RoleBase):
role_id: UUID role_id: UUID
builtin: bool
permissions: List[Permission] permissions: List[Permission]
class Config: class Config:

View File

@ -85,7 +85,7 @@ class UserGroupUpdate(UserGroupBase):
class UserGroup(DateTimeModelMixin, UserGroupBase): class UserGroup(DateTimeModelMixin, UserGroupBase):
user_group_id: UUID user_group_id: UUID
is_updatable: bool builtin: bool
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -50,7 +50,7 @@ class TestGroupRoutes:
response = await client.get(app.url_path_for("get_user_groups")) response = await client.get(app.url_path_for("get_user_groups"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 4 # 3 default groups + group1 assert len(response.json()) == 3 # 2 default groups + group1
async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
@ -206,8 +206,10 @@ class TestGroupRolesRoutes:
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id) roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 1 assert len(roles) == 2 # 1 default role + 1 custom role
assert roles[0].name == test_role.name for role in roles:
if not role.builtin:
assert role.name == test_role.name
async def test_get_user_group_roles( async def test_get_user_group_roles(
self, self,
@ -224,7 +226,7 @@ class TestGroupRolesRoutes:
user_group_id=group_in_db.user_group_id) user_group_id=group_in_db.user_group_id)
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1 assert len(response.json()) == 2 # 1 default role + 1 custom role
async def test_remove_role_from_group( async def test_remove_role_from_group(
self, self,
@ -246,4 +248,5 @@ class TestGroupRolesRoutes:
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id) roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 0 assert len(roles) == 1 # 1 default role
assert roles[0].name != test_role.name

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
pytestmark = pytest.mark.asyncio
class TestPermissionRoutes:
async def test_create_permission(self, app: FastAPI, client: AsyncClient) -> None:
new_permission = {
"methods": ["GET"],
"path": "/templates",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
async def test_get_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
response = await client.get(app.url_path_for("get_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["permission_id"] == str(permission_in_db.permission_id)
async def test_list_permissions(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_permissions"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 5 # 4 default permissions + 1 custom permission
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
update_permission = {
"methods": ["GET"],
"path": "/appliances",
"action": "ALLOW"
}
response = await client.put(
app.url_path_for("update_permission", permission_id=permission_in_db.permission_id),
json=update_permission
)
assert response.status_code == status.HTTP_200_OK
updated_permission_in_db = await rbac_repo.get_permission(permission_in_db.permission_id)
assert updated_permission_in_db.path == "/appliances"
async def test_delete_permission(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/appliances")
response = await client.delete(app.url_path_for("delete_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_204_NO_CONTENT

View File

@ -110,11 +110,11 @@ async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate( new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get, HTTPMethods.post], methods=[HTTPMethods.get, HTTPMethods.post],
path="/projects", path="/templates",
action=PermissionAction.allow action=PermissionAction.allow
) )
rbac_repo = RbacRepository(db_session) rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/projects") existing_permission = await rbac_repo.get_permission_by_path("/templates")
if existing_permission: if existing_permission:
return existing_permission return existing_permission
return await rbac_repo.create_permission(new_permission) return await rbac_repo.create_permission(new_permission)
@ -142,8 +142,7 @@ class TestRolesPermissionsRoutes:
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 1 assert len(permissions) == 4 # 3 default + 1 custom permissions
assert permissions[0].path == test_permission.path
async def test_get_role_permissions( async def test_get_role_permissions(
self, self,
@ -161,7 +160,7 @@ class TestRolesPermissionsRoutes:
role_id=role_in_db.role_id) role_id=role_in_db.role_id)
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1 assert len(response.json()) == 4 # 3 default + 1 custom permissions
async def test_remove_role_from_group( async def test_remove_role_from_group(
self, self,
@ -183,4 +182,4 @@ class TestRolesPermissionsRoutes:
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id) permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 0 assert len(permissions) == 3 # 3 default permissions

View File

@ -266,7 +266,7 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
response = await authorized_client.get(app.url_path_for("get_current_active_user")) response = await authorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
user = User(**response.json()) user = User(**response.json())
assert user.username == test_user.username assert user.username == test_user.username
@ -279,7 +279,7 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
response = await unauthorized_client.get(app.url_path_for("get_current_active_user")) response = await unauthorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -329,15 +329,15 @@ class TestSuperAdmin:
response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
async def test_super_admin_belongs_to_admin_group( # async def test_super_admin_belongs_to_admin_group(
self, # self,
app: FastAPI, # app: FastAPI,
client: AsyncClient, # client: AsyncClient,
db_session: AsyncSession # db_session: AsyncSession
) -> None: # ) -> None:
#
user_repo = UsersRepository(db_session) # user_repo = UsersRepository(db_session)
admin_in_db = await user_repo.get_user_by_username("admin") # admin_in_db = await user_repo.get_user_by_username("admin")
response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id)) # response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
assert response.status_code == status.HTTP_200_OK # assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1 # assert len(response.json()) == 1

View File

@ -124,7 +124,12 @@ async def test_user(db_session: AsyncSession) -> User:
existing_user = await user_repo.get_user_by_username(new_user.username) existing_user = await user_repo.get_user_by_username(new_user.username)
if existing_user: if existing_user:
return existing_user return existing_user
return await user_repo.create_user(new_user) user = await user_repo.create_user(new_user)
# add new user to "Users group
group = await user_repo.get_user_group_by_name("Users")
await user_repo.add_member_to_user_group(group.user_group_id, user)
return user
@pytest.fixture @pytest.fixture