1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-28 11:18:11 +00:00

Add user groups support.

This commit is contained in:
grossmj 2021-05-15 15:10:02 +09:30
parent 956b9056c1
commit 8810249d36
13 changed files with 606 additions and 36 deletions

View File

@ -29,6 +29,7 @@ from . import snapshots
from . import symbols from . import symbols
from . import templates from . import templates
from . import users from . import users
from . import groups
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user
@ -37,6 +38,13 @@ router = APIRouter()
router.include_router(controller.router, tags=["Controller"]) router.include_router(controller.router, tags=["Controller"])
router.include_router(users.router, prefix="/users", tags=["Users"]) router.include_router(users.router, prefix="/users", tags=["Users"])
router.include_router(
groups.router,
dependencies=[Depends(get_current_active_user)],
prefix="/groups",
tags=["Users groups"]
)
router.include_router( router.include_router(
appliances.router, appliances.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
API routes for user groups.
"""
from fastapi import APIRouter, Depends, status
from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
)
from gns3server.db.repositories.users import UsersRepository
from .dependencies.database import get_repository
import logging
log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.UserGroup])
async def get_user_groups(
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.UserGroup]:
"""
Get all user groups.
"""
return await users_repo.get_user_groups()
@router.post(
"",
response_model=schemas.UserGroup,
status_code=status.HTTP_201_CREATED
)
async def create_user_group(
user_group_create: schemas.UserGroupCreate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.UserGroup:
"""
Create a new user group.
"""
if await users_repo.get_user_group_by_name(user_group_create.name):
raise ControllerBadRequestError(f"User group '{user_group_create.name}' already exists")
return await users_repo.create_user_group(user_group_create)
@router.get("/{user_group_id}", response_model=schemas.UserGroup)
async def get_user_group(
user_group_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> schemas.UserGroup:
"""
Get an user group.
"""
user_group = await users_repo.get_user_group(user_group_id)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
return user_group
@router.put("/{user_group_id}", response_model=schemas.UserGroup)
async def update_user_group(
user_group_id: UUID,
user_group_update: schemas.UserGroupUpdate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.UserGroup:
"""
Update an user group.
"""
user_group = await users_repo.get_user_group(user_group_id)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated")
return await users_repo.update_user_group(user_group_id, user_group_update)
@router.delete(
"/{user_group_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def delete_user_group(
user_group_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> None:
"""
Delete an user group
"""
user_group = await users_repo.get_user_group(user_group_id)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted")
success = await users_repo.delete_user_group(user_group_id)
if not success:
raise ControllerNotFoundError(f"User group '{user_group_id}' could not be deleted")
@router.get("/{user_group_id}/members", response_model=List[schemas.User])
async def get_user_group_members(
user_group_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.User]:
"""
Get all user group members.
"""
return await users_repo.get_user_group_members(user_group_id)
@router.put(
"/{user_group_id}/members/{user_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def add_member_to_group(
user_group_id: UUID,
user_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> None:
"""
Add member to an user group.
"""
user = await users_repo.get_user(user_id)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
user_group = await users_repo.add_member_to_user_group(user_group_id, user)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
@router.delete(
"/{user_group_id}/members/{user_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_member_from_group(
user_group_id: UUID,
user_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> None:
"""
Remove member from an user group.
"""
user = await users_repo.get_user(user_id)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
user_group = await users_repo.remove_member_from_user_group(user_group_id, user)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")

View File

@ -185,3 +185,15 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_curre
""" """
return current_user return current_user
@router.get("/{user_id}/groups", response_model=List[schemas.UserGroup])
async def get_user_memberships(
user_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.UserGroup]:
"""
Get user memberships.
"""
return await users_repo.get_user_memberships(user_id)

View File

@ -487,7 +487,7 @@ class Compute:
# Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb) # Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb)
from gns3server.api.server import app from gns3server.api.server import app
if not app.state.exiting and not hasattr(sys, "_called_from_test") or not sys._called_from_test: if not app.state.exiting and not hasattr(sys, "_called_from_test"):
log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'") log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'")
asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect())) asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect()))

View File

@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from .base import Base from .base import Base
from .users import User from .users import User, UserGroup
from .computes import Compute from .computes import Compute
from .templates import ( from .templates import (
Template, Template,

View File

@ -76,8 +76,8 @@ class BaseTable(Base):
__abstract__ = True __abstract__ = True
created_at = Column(DateTime, default=func.current_timestamp()) created_at = Column(DateTime, server_default=func.current_timestamp())
updated_at = Column(DateTime, default=func.current_timestamp(), onupdate=func.current_timestamp()) updated_at = Column(DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
def generate_uuid(): def generate_uuid():

View File

@ -15,15 +15,23 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Boolean, Column, String, event from sqlalchemy import Table, Boolean, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import BaseTable, generate_uuid, GUID from .base import Base, BaseTable, generate_uuid, GUID
from gns3server.services import auth_service from gns3server.services import auth_service
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
users_group_members = Table(
"users_group_members",
Base.metadata,
Column("user_id", GUID, ForeignKey("users.user_id", ondelete="CASCADE")),
Column("user_group_id", GUID, ForeignKey("users_group.user_group_id", ondelete="CASCADE"))
)
class User(BaseTable): class User(BaseTable):
@ -36,6 +44,7 @@ class User(BaseTable):
hashed_password = Column(String) hashed_password = Column(String)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
is_superadmin = Column(Boolean, default=False) is_superadmin = Column(Boolean, default=False)
groups = relationship("UserGroup", secondary=users_group_members, back_populates="users")
@event.listens_for(User.__table__, 'after_create') @event.listens_for(User.__table__, 'after_create')
@ -51,3 +60,46 @@ def create_default_super_admin(target, connection, **kw):
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default super admin account has been created in the database") log.info("The default super admin account has been created in the database")
class UserGroup(BaseTable):
__tablename__ = "users_group"
user_group_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String, unique=True, index=True)
is_updatable = Column(Boolean, default=True)
users = relationship("User", secondary=users_group_members, back_populates="groups")
@event.listens_for(UserGroup.__table__, 'after_create')
def create_default_user_groups(target, connection, **kw):
default_groups = [
{"name": "Administrators", "is_updatable": False},
{"name": "Editors", "is_updatable": False},
{"name": "Users", "is_updatable": False}
]
stmt = target.insert().values(default_groups)
connection.execute(stmt)
connection.commit()
log.info("The default user groups have been created in the database")
@event.listens_for(users_group_members, 'after_create')
def add_admin_to_group(target, connection, **kw):
users_group_table = UserGroup.__table__
stmt = users_group_table.select().where(users_group_table.c.name == "Administrators")
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
users_table = User.__table__
stmt = users_table.select().where(users_table.c.username == "admin")
result = connection.execute(stmt)
user_id = result.first().user_id
stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()

View File

@ -16,9 +16,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID from uuid import UUID
from typing import Optional, List from typing import Optional, List, Union
from sqlalchemy import select, update, delete from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from .base import BaseRepository from .base import BaseRepository
@ -107,3 +108,105 @@ class UsersRepository(BaseRepository):
if not self._auth_service.verify_password(password, user.hashed_password): if not self._auth_service.verify_password(password, user.hashed_password):
return None return None
return user return user
async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]:
query = select(models.UserGroup).\
join(models.UserGroup.users).\
filter(models.User.user_id == user_id)
result = await self._db_session.execute(query)
return result.scalars().all()
async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]:
query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]:
query = select(models.UserGroup).where(models.UserGroup.name == name)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_user_groups(self) -> List[models.UserGroup]:
query = select(models.UserGroup)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup:
db_user_group = models.UserGroup(name=user_group.name)
self._db_session.add(db_user_group)
await self._db_session.commit()
await self._db_session.refresh(db_user_group)
return db_user_group
async def update_user_group(
self,
user_group_id: UUID,
user_group_update: schemas.UserGroupUpdate
) -> Optional[models.UserGroup]:
update_values = user_group_update.dict(exclude_unset=True)
query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values)
await self._db_session.execute(query)
await self._db_session.commit()
return await self.get_user_group(user_group_id)
async def delete_user_group(self, user_group_id: UUID) -> bool:
query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0
async def add_member_to_user_group(
self,
user_group_id: UUID,
user: models.User
) -> Union[None, models.UserGroup]:
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.users.append(user)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def remove_member_from_user_group(
self,
user_group_id: UUID,
user: models.User
) -> Union[None, models.UserGroup]:
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.users.remove(user)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]:
query = select(models.User).\
join(models.User.groups).\
filter(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
return result.scalars().all()

View File

@ -27,7 +27,7 @@ from .controller.drawings import Drawing
from .controller.gns3vm import GNS3VM from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
from .controller.users import UserCreate, UserUpdate, User, Credentials from .controller.users import UserCreate, UserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
from .controller.tokens import Token from .controller.tokens import Token
from .controller.snapshots import SnapshotCreate, Snapshot from .controller.snapshots import SnapshotCreate, Snapshot
from .controller.iou_license import IOULicense from .controller.iou_license import IOULicense

View File

@ -58,6 +58,39 @@ class User(DateTimeModelMixin, UserBase):
orm_mode = True orm_mode = True
class UserGroupBase(BaseModel):
"""
Common user group properties.
"""
name: Optional[str] = Field(None, min_length=3, regex="[a-zA-Z0-9_-]+$")
class UserGroupCreate(UserGroupBase):
"""
Properties to create an user group.
"""
name: Optional[str] = Field(..., min_length=3, regex="[a-zA-Z0-9_-]+$")
class UserGroupUpdate(UserGroupBase):
"""
Properties to update an user group.
"""
pass
class UserGroup(DateTimeModelMixin, UserGroupBase):
user_group_id: UUID
is_updatable: bool
class Config:
orm_mode = True
class Credentials(BaseModel): class Credentials(BaseModel):
username: str username: str

View File

@ -319,6 +319,7 @@ class Server:
access_log=access_log, access_log=access_log,
ssl_certfile=config.Server.certfile, ssl_certfile=config.Server.certfile,
ssl_keyfile=config.Server.certkey, ssl_keyfile=config.Server.certkey,
lifespan="on"
) )
# overwrite uvicorn loggers with our own logger # overwrite uvicorn loggers with our own logger

View File

@ -0,0 +1,165 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository
from gns3server.schemas.controller.users import User
pytestmark = pytest.mark.asyncio
class TestGroupRoutes:
async def test_create_group(self, app: FastAPI, client: AsyncClient) -> None:
new_group = {"name": "group1"}
response = await client.post(app.url_path_for("create_user_group"), json=new_group)
assert response.status_code == status.HTTP_201_CREATED
async def test_get_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("group1")
response = await client.get(app.url_path_for("get_user_group", user_group_id=group_in_db.user_group_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["user_group_id"] == str(group_in_db.user_group_id)
async def test_list_groups(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_user_groups"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 4 # 3 default groups + group1
async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("group1")
update_group = {"name": "group42"}
response = await client.put(
app.url_path_for("update_user_group", user_group_id=group_in_db.user_group_id),
json=update_group
)
assert response.status_code == status.HTTP_200_OK
updated_group_in_db = await user_repo.get_user_group(group_in_db.user_group_id)
assert updated_group_in_db.name == "group42"
async def test_cannot_update_admin_group(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Administrators")
update_group = {"name": "Hackers"}
response = await client.put(
app.url_path_for("update_user_group", user_group_id=group_in_db.user_group_id),
json=update_group
)
assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_delete_group(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("group42")
response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_cannot_delete_admin_group(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Administrators")
response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id))
assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_add_member_to_group(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.put(
app.url_path_for(
"add_member_to_group",
user_group_id=group_in_db.user_group_id,
user_id=str(test_user.user_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
members = await user_repo.get_user_group_members(group_in_db.user_group_id)
assert len(members) == 1
assert members[0].username == test_user.username
async def test_get_user_group_members(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.get(
app.url_path_for(
"get_user_group_members",
user_group_id=group_in_db.user_group_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
async def test_remove_member_from_group(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.delete(
app.url_path_for(
"remove_member_from_group",
user_group_id=group_in_db.user_group_id,
user_id=str(test_user.user_id)
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
members = await user_repo.get_user_group_members(group_in_db.user_group_id)
assert len(members) == 0

View File

@ -56,8 +56,8 @@ class TestUserRoutes:
assert user_in_db is None assert user_in_db is None
# register the user # register the user
res = await client.post(app.url_path_for("create_user"), json=params) response = await client.post(app.url_path_for("create_user"), json=params)
assert res.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
# make sure the user does exists in the database now # make sure the user does exists in the database now
user_in_db = await user_repo.get_user_by_username(params["username"]) user_in_db = await user_repo.get_user_by_username(params["username"])
@ -66,7 +66,7 @@ class TestUserRoutes:
assert user_in_db.username == params["username"] assert user_in_db.username == params["username"]
# check that the user returned in the response is equal to the user in the database # check that the user returned in the response is equal to the user in the database
created_user = User(**res.json()).json() created_user = User(**response.json()).json()
assert created_user == User.from_orm(user_in_db).json() assert created_user == User.from_orm(user_in_db).json()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -91,8 +91,8 @@ class TestUserRoutes:
new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"} new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"}
new_user[attr] = value new_user[attr] = value
res = await client.post(app.url_path_for("create_user"), json=new_user) response = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status_code assert response.status_code == status_code
async def test_users_saved_password_is_hashed( async def test_users_saved_password_is_hashed(
self, self,
@ -105,8 +105,8 @@ class TestUserRoutes:
new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"} new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"}
# send post request to create user and ensure it is successful # send post request to create user and ensure it is successful
res = await client.post(app.url_path_for("create_user"), json=new_user) response = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
# ensure that the users password is hashed in the db # ensure that the users password is hashed in the db
# and that we can verify it using our auth service # and that we can verify it using our auth service
@ -156,7 +156,6 @@ class TestAuthTokens:
username = auth_service.get_username_from_token(token) username = auth_service.get_username_from_token(token)
assert username == test_user.username assert username == test_user.username
@pytest.mark.parametrize( @pytest.mark.parametrize(
"wrong_secret, wrong_token", "wrong_secret, wrong_token",
( (
@ -200,19 +199,19 @@ class TestUserLogin:
"username": test_user.username, "username": test_user.username,
"password": "user1_password", "password": "user1_password",
} }
res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
# check that token exists in response and has user encoded within it # check that token exists in response and has user encoded within it
token = res.json().get("access_token") token = response.json().get("access_token")
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
assert "sub" in payload assert "sub" in payload
username = payload.get("sub") username = payload.get("sub")
assert username == test_user.username assert username == test_user.username
# check that token is proper type # check that token is proper type
assert "token_type" in res.json() assert "token_type" in response.json()
assert res.json().get("token_type") == "bearer" assert response.json().get("token_type") == "bearer"
async def test_user_can_authenticate_using_json( async def test_user_can_authenticate_using_json(
self, self,
@ -226,9 +225,9 @@ class TestUserLogin:
"username": test_user.username, "username": test_user.username,
"password": "user1_password", "password": "user1_password",
} }
res = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials) response = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials)
assert res.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert res.json().get("access_token") assert response.json().get("access_token")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"username, password, status_code", "username, password, status_code",
@ -253,9 +252,9 @@ class TestUserLogin:
"username": username, "username": username,
"password": password, "password": password,
} }
res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status_code assert response.status_code == status_code
assert "access_token" not in res.json() assert "access_token" not in response.json()
class TestUserMe: class TestUserMe:
@ -267,9 +266,9 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
res = await authorized_client.get(app.url_path_for("get_current_active_user")) response = await authorized_client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
user = User(**res.json()) user = User(**response.json())
assert user.username == test_user.username assert user.username == test_user.username
assert user.email == test_user.email assert user.email == test_user.email
assert user.user_id == test_user.user_id assert user.user_id == test_user.user_id
@ -280,8 +279,8 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
res = await unauthorized_client.get(app.url_path_for("get_current_active_user")) response = await unauthorized_client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
class TestSuperAdmin: class TestSuperAdmin:
@ -307,8 +306,8 @@ class TestSuperAdmin:
user_repo = UsersRepository(db_session) user_repo = UsersRepository(db_session)
admin_in_db = await user_repo.get_user_by_username("admin") admin_in_db = await user_repo.get_user_by_username("admin")
res = await client.delete(app.url_path_for("delete_user", user_id=admin_in_db.user_id)) response = await client.delete(app.url_path_for("delete_user", user_id=admin_in_db.user_id))
assert res.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_admin_can_login_after_password_recovery( async def test_admin_can_login_after_password_recovery(
self, self,
@ -327,5 +326,18 @@ class TestSuperAdmin:
"username": "admin", "username": "admin",
"password": "whatever", "password": "whatever",
} }
res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
async def test_super_admin_belongs_to_admin_group(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
admin_in_db = await user_repo.get_user_by_username("admin")
response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1