1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-12 19:38:57 +00:00

User authentication with tests.

This commit is contained in:
grossmj 2020-12-07 16:52:36 +10:30
parent bf7cf862af
commit d47dcb0d6f
14 changed files with 452 additions and 180 deletions

View File

@ -4,6 +4,5 @@ pytest==6.1.2
flake8==3.8.4
pytest-timeout==1.4.2
pytest-asyncio==0.14.0
asgi-lifespan==1.0.1
requests==2.24.0
httpx==0.16.1

View File

@ -23,12 +23,15 @@ from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
async def get_user_from_token(token: str = Depends(oauth2_scheme),
user_repo: UsersRepository = Depends()) -> schemas.User:
async def get_user_from_token(
token: str = Depends(oauth2_scheme),
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
username = auth_service.get_username_from_token(token)
user = await user_repo.get_user_by_username(username)

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 GNS3 Technologies Inc.
#
@ -15,12 +15,24 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
from typing import Callable, Type
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import declarative_base
from gns3server.db.repositories.base import BaseRepository
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URI", "sqlite:///./sql_app.db")
engine = create_async_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
Base = declarative_base()
async def get_db_session(request: Request) -> AsyncSession:
session = AsyncSession(request.app.state._db_engine)
try:
yield session
finally:
await session.close()
def get_repository(repo: Type[BaseRepository]) -> Callable:
def get_repo(db_session: AsyncSession = Depends(get_db_session)) -> Type[BaseRepository]:
return repo(db_session)
return get_repo

View File

@ -20,7 +20,7 @@ API routes for users.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.security import OAuth2PasswordRequestForm
from uuid import UUID
from typing import List
@ -30,6 +30,7 @@ from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository
import logging
log = logging.getLogger(__name__)
@ -38,7 +39,7 @@ router = APIRouter()
@router.get("", response_model=List[schemas.User])
async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User]:
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
"""
Get all users.
"""
@ -48,7 +49,10 @@ async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository = Depends()) -> schemas.User:
async def create_user(
new_user: schemas.UserCreate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Create a new user.
"""
@ -63,7 +67,10 @@ async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository =
@router.get("/{user_id}", response_model=schemas.User)
async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> schemas.User:
async def get_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Get an user.
"""
@ -75,9 +82,11 @@ async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> sch
@router.put("/{user_id}", response_model=schemas.User)
async def update_user(user_id: UUID,
update_user: schemas.UserUpdate,
user_repo: UsersRepository = Depends()) -> schemas.User:
async def update_user(
user_id: UUID,
update_user: schemas.UserUpdate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Update an user.
"""
@ -89,7 +98,7 @@ async def update_user(user_id: UUID,
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
"""
Delete an user.
"""
@ -100,8 +109,10 @@ async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
@router.post("/login", response_model=schemas.Token)
async def login(user_repo: UsersRepository = Depends(),
form_data: OAuth2PasswordRequestForm = Depends()) -> schemas.Token:
async def login(
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends()
) -> schemas.Token:
"""
User login.
"""

View File

@ -45,34 +45,43 @@ import logging
log = logging.getLogger(__name__)
app = FastAPI(title="GNS3 controller API",
description="This page describes the public controller API for GNS3",
version="v3")
def get_application() -> FastAPI:
origins = [
"http://127.0.0.1",
"http://localhost",
"http://127.0.0.1:8080",
"http://localhost:8080",
"http://127.0.0.1:3080",
"http://localhost:3080",
"http://gns3.github.io",
"https://gns3.github.io"
]
application = FastAPI(
title="GNS3 controller API",
description="This page describes the public controller API for GNS3",
version="v3"
)
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
origins = [
"http://127.0.0.1",
"http://localhost",
"http://127.0.0.1:8080",
"http://localhost:8080",
"http://127.0.0.1:3080",
"http://localhost:3080",
"http://gns3.github.io",
"https://gns3.github.io"
]
app.add_event_handler("startup", tasks.create_startup_handler(app))
app.add_event_handler("shutdown", tasks.create_shutdown_handler(app))
app.include_router(index.router, tags=["Index"])
app.include_router(controller.router, prefix="/v3")
app.mount("/v3/compute", compute_api)
application.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
application.add_event_handler("startup", tasks.create_startup_handler(application))
application.add_event_handler("shutdown", tasks.create_shutdown_handler(application))
application.include_router(index.router, tags=["Index"])
application.include_router(controller.router, prefix="/v3")
application.mount("/v3/compute", compute_api)
return application
app = get_application()
@app.exception_handler(ControllerError)

View File

@ -55,7 +55,7 @@ def create_startup_handler(app: FastAPI) -> Callable:
loop.set_debug(True)
# connect to the database
await connect_to_db()
await connect_to_db(app)
await Controller.instance().start()
# Because with a large image collection

View File

@ -21,7 +21,10 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, f
from sqlalchemy.orm import relationship
from sqlalchemy.types import TypeDecorator, CHAR
from sqlalchemy.dialects.postgresql import UUID
from .database import Base
from sqlalchemy.orm import declarative_base
Base = declarative_base()
class GUID(TypeDecorator):
@ -68,11 +71,14 @@ class BaseTable(Base):
onupdate=func.current_timestamp())
def generate_uuid():
return str(uuid.uuid4())
class User(BaseTable):
__tablename__ = "users"
user_id = Column(GUID, primary_key=True, default=str(uuid.uuid4()))
user_id = Column(GUID, primary_key=True, default=generate_uuid)
username = Column(String, unique=True, index=True)
email = Column(String, unique=True, index=True)
full_name = Column(String)

View File

@ -16,14 +16,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import engine
class BaseRepository:
async def db(self):
session = AsyncSession(engine)
try:
yield session
finally:
await session.close()
def __init__(self, db_session: AsyncSession) -> None:
self._db_session = db_session

View File

@ -20,7 +20,6 @@ from typing import Optional, List
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import engine
from .base import BaseRepository
import gns3server.db.models as models
@ -30,73 +29,66 @@ from gns3server.services import auth_service
class UsersRepository(BaseRepository):
def __init__(self) -> None:
def __init__(self, db_session: AsyncSession) -> None:
super().__init__()
super().__init__(db_session)
self._auth_service = auth_service
async def get_user(self, user_id: UUID) -> Optional[models.User]:
async with AsyncSession(engine) as session:
result = await session.execute(select(models.User).where(models.User.user_id == user_id))
return result.scalars().first()
result = await self._db_session.execute(select(models.User).where(models.User.user_id == user_id))
return result.scalars().first()
async def get_user_by_username(self, username: str) -> Optional[models.User]:
async with AsyncSession(engine) as session:
result = await session.execute(select(models.User).where(models.User.username == username))
return result.scalars().first()
result = await self._db_session.execute(select(models.User).where(models.User.username == username))
return result.scalars().first()
async def get_user_by_email(self, email: str) -> Optional[models.User]:
async with AsyncSession(engine) as session:
result = await session.execute(select(models.User).where(models.User.email == email))
return result.scalars().first()
result = await self._db_session.execute(select(models.User).where(models.User.email == email))
return result.scalars().first()
async def get_users(self) -> List[models.User]:
async with AsyncSession(engine) as session:
result = await session.execute(select(models.User))
return result.scalars().all()
result = await self._db_session.execute(select(models.User))
return result.scalars().all()
async def create_user(self, user: schemas.UserCreate) -> models.User:
async with AsyncSession(engine) as session:
hashed_password = self._auth_service.hash_password(user.password)
db_user = models.User(username=user.username,
email=user.email,
full_name=user.full_name,
hashed_password=hashed_password)
session.add(db_user)
await session.commit()
await session.refresh(db_user)
return db_user
hashed_password = self._auth_service.hash_password(user.password)
db_user = models.User(
username=user.username,
email=user.email,
full_name=user.full_name,
hashed_password=hashed_password
)
self._db_session.add(db_user)
await self._db_session.commit()
await self._db_session.refresh(db_user)
return db_user
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
async with AsyncSession(engine) as session:
update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None)
if password:
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None)
if password:
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
query = update(models.User) \
.where(models.User.user_id == user_id) \
.values(update_values)
print(update_values)
query = update(models.User) \
.where(models.User.user_id == user_id) \
.values(update_values)
await session.execute(query)
await session.commit()
return await self.get_user(user_id)
await self._db_session.execute(query)
await self._db_session.commit()
return await self.get_user(user_id)
async def delete_user(self, user_id: UUID) -> bool:
async with AsyncSession(engine) as session:
query = delete(models.User).where(models.User.user_id == user_id)
result = await session.execute(query)
await session.commit()
return result.rowcount > 0
query = delete(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0
#except:
# await session.rollback()

View File

@ -15,20 +15,26 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy.exc import SQLAlchemyError
import os
from fastapi import FastAPI
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine
from .database import engine
from .models import Base
import logging
log = logging.getLogger(__name__)
async def connect_to_db() -> None:
async def connect_to_db(app: FastAPI) -> None:
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
log.info("Successfully connected to the database")
app.state._db_engine = engine
except SQLAlchemyError as e:
log.error(f"Error while connecting to the database: {e}")

View File

@ -35,10 +35,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
#class AuthException(BaseException):
# pass
class AuthService:
def hash_password(self, password: str) -> str:
@ -49,14 +45,19 @@ class AuthService:
return pwd_context.verify(password, hashed_password)
def create_access_token(self, username):
def create_access_token(
self,
username,
secret_key: str = SECRET_KEY,
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
) -> str:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire}
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
return encoded_jwt
def get_username_from_token(self, token: str) -> Optional[str]:
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -64,7 +65,7 @@ class AuthService:
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception

View File

@ -1,5 +1,5 @@
uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841
fastapi==0.61.2
fastapi==0.62.0
websockets==8.1
python-multipart==0.0.5
aiohttp==3.7.2

View File

@ -17,46 +17,237 @@
import pytest
from fastapi import FastAPI, status
from fastapi.encoders import jsonable_encoder
from typing import Optional, Union
from fastapi import FastAPI, HTTPException, status
from starlette.datastructures import Secret
from httpx import AsyncClient
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
from gns3server.schemas.users import User
pytestmark = pytest.mark.asyncio
# async def test_route_exist(app: FastAPI, client: AsyncClient) -> None:
#
# params = {"username": "test_username", "email": "user@email.com", "password": "test_password"}
# response = await client.post(app.url_path_for("create_user"), json=params)
# assert response.status_code != status.HTTP_404_NOT_FOUND
#
#
# async def test_users_can_register_successfully(app: FastAPI, client: AsyncClient) -> None:
#
# user_repo = UsersRepository()
# params = {"username": "test_username2", "email": "user2@email.com", "password": "test_password2"}
#
# # make sure the user doesn't exist in the database
# user_in_db = await user_repo.get_user_by_username(params["username"])
# assert user_in_db is None
#
# # register the user
# res = await client.post(app.url_path_for("create_user"), json=params)
# assert res.status_code == status.HTTP_201_CREATED
#
# # make sure the user does exists in the database now
# user_in_db = await user_repo.get_user_by_username(params["username"])
# assert user_in_db is not None
# assert user_in_db.email == params["email"]
# assert user_in_db.username == params["username"]
#
# # check that the user returned in the response is equal to the user in the database
# created_user = User(**res.json()).json()
# print(created_user)
# #print(user_in_db.__dict__)
# test = jsonable_encoder(user_in_db.__dict__, exclude={"_sa_instance_state", "hashed_password"})
# print(test)
# assert created_user == test
class TestUserRoutes:
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
response = await client.post(app.url_path_for("create_user"), json=new_user)
assert response.status_code != status.HTTP_404_NOT_FOUND
async def test_users_can_register_successfully(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
# make sure the user doesn't exist in the database
user_in_db = await user_repo.get_user_by_username(params["username"])
assert user_in_db is None
# register the user
res = await client.post(app.url_path_for("create_user"), json=params)
assert res.status_code == status.HTTP_201_CREATED
# make sure the user does exists in the database now
user_in_db = await user_repo.get_user_by_username(params["username"])
assert user_in_db is not None
assert user_in_db.email == params["email"]
assert user_in_db.username == params["username"]
# check that the user returned in the response is equal to the user in the database
created_user = User(**res.json()).json()
assert created_user == User.from_orm(user_in_db).json()
@pytest.mark.parametrize(
"attr, value, status_code",
(
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "ab", status.HTTP_422_UNPROCESSABLE_ENTITY),
)
)
async def test_user_registration_fails_when_credentials_are_taken(
self,
app: FastAPI,
client: AsyncClient,
attr: str,
value: str,
status_code: int,
) -> None:
new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"}
new_user[attr] = value
res = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status_code
async def test_users_saved_password_is_hashed(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
# send post request to create user and ensure it is successful
res = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status.HTTP_201_CREATED
# ensure that the users password is hashed in the db
# and that we can verify it using our auth service
user_in_db = await user_repo.get_user_by_username(new_user["username"])
assert user_in_db is not None
assert user_in_db.hashed_password != new_user["password"]
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
class TestAuthTokens:
async def test_can_create_token_successfully(
self,
app: FastAPI,
client: AsyncClient,
test_user: User
) -> None:
token = auth_service.create_access_token(test_user.username)
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username = payload.get("sub")
assert username == test_user.username
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
token = auth_service.create_access_token(None)
with pytest.raises(jwt.JWTError):
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
async def test_can_retrieve_username_from_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User
) -> None:
token = auth_service.create_access_token(test_user.username)
username = auth_service.get_username_from_token(token)
assert username == test_user.username
@pytest.mark.parametrize(
"secret, wrong_token",
(
(SECRET_KEY, "asdf"), # use wrong token
(SECRET_KEY, ""), # use wrong token
("ABC123", "use correct token"), # use wrong secret
),
)
async def test_error_when_token_or_secret_is_wrong(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
secret: Union[Secret, str],
wrong_token: Optional[str],
) -> None:
token = auth_service.create_access_token(test_user.username)
if wrong_token == "use correct token":
wrong_token = token
with pytest.raises(HTTPException):
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
class TestUserLogin:
async def test_user_can_login_successfully_and_receives_valid_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
) -> None:
client.headers["content-type"] = "application/x-www-form-urlencoded"
login_data = {
"username": test_user.username,
"password": "user1_password",
}
res = await client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status.HTTP_200_OK
# check that token exists in response and has user encoded within it
token = res.json().get("access_token")
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert "sub" in payload
username = payload.get("sub")
assert username == test_user.username
# check that token is proper type
assert "token_type" in res.json()
assert res.json().get("token_type") == "bearer"
@pytest.mark.parametrize(
"username, password, status_code",
(
("wrong_username", "user1_password", status.HTTP_401_UNAUTHORIZED),
("user1", "wrong_password", status.HTTP_401_UNAUTHORIZED),
("user1", None, status.HTTP_401_UNAUTHORIZED),
),
)
async def test_user_with_wrong_creds_doesnt_receive_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
username: str,
password: str,
status_code: int,
) -> None:
client.headers["content-type"] = "application/x-www-form-urlencoded"
login_data = {
"username": username,
"password": password,
}
res = await client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status_code
assert "access_token" not in res.json()
class TestUserMe:
async def test_authenticated_user_can_retrieve_own_data(
self,
app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
) -> None:
res = await authorized_client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_200_OK
user = User(**res.json())
assert user.username == test_user.username
assert user.email == test_user.email
assert user.user_id == test_user.user_id
async def test_user_cannot_access_own_data_if_not_authenticated(
self, app: FastAPI,
client: AsyncClient,
test_user: User,
) -> None:
res = await client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_401_UNAUTHORIZED

View File

@ -7,7 +7,6 @@ import os
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from unittest.mock import MagicMock, patch
from pathlib import Path
@ -17,25 +16,16 @@ from gns3server.config import Config
from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager
from gns3server.compute.project_manager import ProjectManager
from gns3server.db.database import Base
from gns3server.db.models import Base, User
from gns3server.db.repositories.users import UsersRepository
from gns3server.api.routes.controller.dependencies.database import get_db_session
from gns3server.schemas.users import UserCreate
from gns3server.services import auth_service
sys._called_from_test = True
sys.original_platform = sys.platform
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
async def start_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
if sys.platform.startswith("win") and sys.version_info < (3, 8):
@pytest.yield_fixture(scope="session")
def event_loop(request):
@ -49,34 +39,64 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
yield loop
asyncio.set_event_loop(None)
# https://github.com/pytest-dev/pytest-asyncio/issues/68
# this event_loop is used by pytest-asyncio, and redefining it
# is currently the only way of changing the scope of this fixture
@pytest.yield_fixture(scope="session")
def event_loop(request):
# @pytest.mark.asyncio
# @pytest.fixture(scope="session", autouse=True)
# async def database_connection() -> None:
#
# from gns3server.db.tasks import connect_to_db
# os.environ["DATABASE_URI"] = "sqlite:///./sql_app_test.db"
# await connect_to_db()
# yield
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture#(scope="session")
async def app() -> FastAPI:
@pytest.fixture(scope="session")
async def app(db_engine) -> FastAPI:
from gns3server.api.server import app as gns3_app
gns3_app.add_event_handler("startup", start_db())
return gns3_app
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
from gns3server.api.server import app as gns3app
yield gns3app
# Grab a reference to our database when needed
#@pytest.fixture
#def db(app: FastAPI) -> Database:
# return app.state._db
@pytest.fixture(scope="session")
def db_engine():
db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite:///:memory:") # "sqlite:///./sql_test_app.db"
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
yield engine
engine.sync_engine.dispose()
@pytest.fixture(scope="class")
async def db_session(app: FastAPI, db_engine):
# recreate database tables for each class
# preferred and faster way would be to rollback the session/transaction
# but it doesn't work for some reason
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
session = AsyncSession(db_engine)
try:
yield session
finally:
await session.close()
@pytest.fixture
async def client(app: FastAPI) -> AsyncClient:
async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
async def _get_test_db():
try:
yield db_session
finally:
pass
app.dependency_overrides[get_db_session] = _get_test_db
#async with LifespanManager(app):
async with AsyncClient(
app=app,
base_url="http://test-api",
@ -85,6 +105,32 @@ async def client(app: FastAPI) -> AsyncClient:
yield client
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
new_user = UserCreate(
username="user1",
email="user1@email.com",
password="user1_password",
)
user_repo = UsersRepository(db_session)
existing_user = await user_repo.get_user_by_username(new_user.username)
if existing_user:
return existing_user
return await user_repo.create_user(new_user)
@pytest.fixture
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient:
access_token = auth_service.create_access_token(test_user.username)
client.headers = {
**client.headers,
"Authorization": f"Bearer {access_token}",
}
return client
@pytest.fixture
def controller_config_path(tmpdir):