mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-16 11:00:58 +00:00
User authentication with tests.
This commit is contained in:
parent
bf7cf862af
commit
d47dcb0d6f
@ -4,6 +4,5 @@ pytest==6.1.2
|
||||
flake8==3.8.4
|
||||
pytest-timeout==1.4.2
|
||||
pytest-asyncio==0.14.0
|
||||
asgi-lifespan==1.0.1
|
||||
requests==2.24.0
|
||||
httpx==0.16.1
|
||||
|
@ -23,12 +23,15 @@ from gns3server import schemas
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.services import auth_service
|
||||
|
||||
from .database import get_repository
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
|
||||
|
||||
|
||||
async def get_user_from_token(token: str = Depends(oauth2_scheme),
|
||||
user_repo: UsersRepository = Depends()) -> schemas.User:
|
||||
async def get_user_from_token(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
|
||||
username = auth_service.get_username_from_token(token)
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
|
@ -1,4 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (C) 2020 GNS3 Technologies Inc.
|
||||
#
|
||||
@ -15,12 +15,24 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
from typing import Callable, Type
|
||||
from fastapi import Depends, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from gns3server.db.repositories.base import BaseRepository
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URI", "sqlite:///./sql_app.db")
|
||||
|
||||
engine = create_async_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
Base = declarative_base()
|
||||
async def get_db_session(request: Request) -> AsyncSession:
|
||||
|
||||
session = AsyncSession(request.app.state._db_engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_repository(repo: Type[BaseRepository]) -> Callable:
|
||||
|
||||
def get_repo(db_session: AsyncSession = Depends(get_db_session)) -> Type[BaseRepository]:
|
||||
return repo(db_session)
|
||||
return get_repo
|
@ -20,7 +20,7 @@ API routes for users.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from uuid import UUID
|
||||
from typing import List
|
||||
|
||||
@ -30,6 +30,7 @@ from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.services import auth_service
|
||||
|
||||
from .dependencies.authentication import get_current_active_user
|
||||
from .dependencies.database import get_repository
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
@ -38,7 +39,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[schemas.User])
|
||||
async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User]:
|
||||
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
|
||||
"""
|
||||
Get all users.
|
||||
"""
|
||||
@ -48,7 +49,10 @@ async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User
|
||||
|
||||
|
||||
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository = Depends()) -> schemas.User:
|
||||
async def create_user(
|
||||
new_user: schemas.UserCreate,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Create a new user.
|
||||
"""
|
||||
@ -63,7 +67,10 @@ async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository =
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=schemas.User)
|
||||
async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> schemas.User:
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Get an user.
|
||||
"""
|
||||
@ -75,9 +82,11 @@ async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> sch
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=schemas.User)
|
||||
async def update_user(user_id: UUID,
|
||||
update_user: schemas.UserUpdate,
|
||||
user_repo: UsersRepository = Depends()) -> schemas.User:
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
update_user: schemas.UserUpdate,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Update an user.
|
||||
"""
|
||||
@ -89,7 +98,7 @@ async def update_user(user_id: UUID,
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
|
||||
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
|
||||
"""
|
||||
Delete an user.
|
||||
"""
|
||||
@ -100,8 +109,10 @@ async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
|
||||
|
||||
|
||||
@router.post("/login", response_model=schemas.Token)
|
||||
async def login(user_repo: UsersRepository = Depends(),
|
||||
form_data: OAuth2PasswordRequestForm = Depends()) -> schemas.Token:
|
||||
async def login(
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
form_data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> schemas.Token:
|
||||
"""
|
||||
User login.
|
||||
"""
|
||||
|
@ -45,34 +45,43 @@ import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
app = FastAPI(title="GNS3 controller API",
|
||||
description="This page describes the public controller API for GNS3",
|
||||
version="v3")
|
||||
def get_application() -> FastAPI:
|
||||
|
||||
origins = [
|
||||
"http://127.0.0.1",
|
||||
"http://localhost",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:3080",
|
||||
"http://localhost:3080",
|
||||
"http://gns3.github.io",
|
||||
"https://gns3.github.io"
|
||||
]
|
||||
application = FastAPI(
|
||||
title="GNS3 controller API",
|
||||
description="This page describes the public controller API for GNS3",
|
||||
version="v3"
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
origins = [
|
||||
"http://127.0.0.1",
|
||||
"http://localhost",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:3080",
|
||||
"http://localhost:3080",
|
||||
"http://gns3.github.io",
|
||||
"https://gns3.github.io"
|
||||
]
|
||||
|
||||
app.add_event_handler("startup", tasks.create_startup_handler(app))
|
||||
app.add_event_handler("shutdown", tasks.create_shutdown_handler(app))
|
||||
app.include_router(index.router, tags=["Index"])
|
||||
app.include_router(controller.router, prefix="/v3")
|
||||
app.mount("/v3/compute", compute_api)
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
application.add_event_handler("startup", tasks.create_startup_handler(application))
|
||||
application.add_event_handler("shutdown", tasks.create_shutdown_handler(application))
|
||||
application.include_router(index.router, tags=["Index"])
|
||||
application.include_router(controller.router, prefix="/v3")
|
||||
application.mount("/v3/compute", compute_api)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = get_application()
|
||||
|
||||
|
||||
@app.exception_handler(ControllerError)
|
||||
|
@ -55,7 +55,7 @@ def create_startup_handler(app: FastAPI) -> Callable:
|
||||
loop.set_debug(True)
|
||||
|
||||
# connect to the database
|
||||
await connect_to_db()
|
||||
await connect_to_db(app)
|
||||
|
||||
await Controller.instance().start()
|
||||
# Because with a large image collection
|
||||
|
@ -21,7 +21,10 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, f
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import TypeDecorator, CHAR
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from .database import Base
|
||||
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class GUID(TypeDecorator):
|
||||
@ -68,11 +71,14 @@ class BaseTable(Base):
|
||||
onupdate=func.current_timestamp())
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
class User(BaseTable):
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id = Column(GUID, primary_key=True, default=str(uuid.uuid4()))
|
||||
user_id = Column(GUID, primary_key=True, default=generate_uuid)
|
||||
username = Column(String, unique=True, index=True)
|
||||
email = Column(String, unique=True, index=True)
|
||||
full_name = Column(String)
|
||||
|
@ -16,14 +16,10 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from ..database import engine
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
|
||||
async def db(self):
|
||||
session = AsyncSession(engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
def __init__(self, db_session: AsyncSession) -> None:
|
||||
|
||||
self._db_session = db_session
|
||||
|
@ -20,7 +20,6 @@ from typing import Optional, List
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..database import engine
|
||||
from .base import BaseRepository
|
||||
|
||||
import gns3server.db.models as models
|
||||
@ -30,73 +29,66 @@ from gns3server.services import auth_service
|
||||
|
||||
class UsersRepository(BaseRepository):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, db_session: AsyncSession) -> None:
|
||||
|
||||
super().__init__()
|
||||
super().__init__(db_session)
|
||||
self._auth_service = auth_service
|
||||
|
||||
async def get_user(self, user_id: UUID) -> Optional[models.User]:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.execute(select(models.User).where(models.User.user_id == user_id))
|
||||
return result.scalars().first()
|
||||
result = await self._db_session.execute(select(models.User).where(models.User.user_id == user_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[models.User]:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.execute(select(models.User).where(models.User.username == username))
|
||||
return result.scalars().first()
|
||||
result = await self._db_session.execute(select(models.User).where(models.User.username == username))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[models.User]:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.execute(select(models.User).where(models.User.email == email))
|
||||
return result.scalars().first()
|
||||
result = await self._db_session.execute(select(models.User).where(models.User.email == email))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_users(self) -> List[models.User]:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.execute(select(models.User))
|
||||
return result.scalars().all()
|
||||
result = await self._db_session.execute(select(models.User))
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_user(self, user: schemas.UserCreate) -> models.User:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
hashed_password = self._auth_service.hash_password(user.password)
|
||||
db_user = models.User(username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
hashed_password=hashed_password)
|
||||
session.add(db_user)
|
||||
await session.commit()
|
||||
await session.refresh(db_user)
|
||||
return db_user
|
||||
hashed_password = self._auth_service.hash_password(user.password)
|
||||
db_user = models.User(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
hashed_password=hashed_password
|
||||
)
|
||||
self._db_session.add(db_user)
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
update_values = user_update.dict(exclude_unset=True)
|
||||
password = update_values.pop("password", None)
|
||||
if password:
|
||||
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
|
||||
|
||||
update_values = user_update.dict(exclude_unset=True)
|
||||
password = update_values.pop("password", None)
|
||||
if password:
|
||||
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
|
||||
query = update(models.User) \
|
||||
.where(models.User.user_id == user_id) \
|
||||
.values(update_values)
|
||||
|
||||
print(update_values)
|
||||
query = update(models.User) \
|
||||
.where(models.User.user_id == user_id) \
|
||||
.values(update_values)
|
||||
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return await self.get_user(user_id)
|
||||
await self._db_session.execute(query)
|
||||
await self._db_session.commit()
|
||||
return await self.get_user(user_id)
|
||||
|
||||
async def delete_user(self, user_id: UUID) -> bool:
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
query = delete(models.User).where(models.User.user_id == user_id)
|
||||
result = await session.execute(query)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
query = delete(models.User).where(models.User.user_id == user_id)
|
||||
result = await self._db_session.execute(query)
|
||||
await self._db_session.commit()
|
||||
return result.rowcount > 0
|
||||
#except:
|
||||
# await session.rollback()
|
||||
|
||||
|
@ -15,20 +15,26 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from .database import engine
|
||||
from .models import Base
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def connect_to_db() -> None:
|
||||
async def connect_to_db(app: FastAPI) -> None:
|
||||
|
||||
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
|
||||
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
log.info("Successfully connected to the database")
|
||||
app.state._db_engine = engine
|
||||
except SQLAlchemyError as e:
|
||||
log.error(f"Error while connecting to the database: {e}")
|
||||
|
@ -35,10 +35,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
#class AuthException(BaseException):
|
||||
# pass
|
||||
|
||||
|
||||
class AuthService:
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
@ -49,14 +45,19 @@ class AuthService:
|
||||
|
||||
return pwd_context.verify(password, hashed_password)
|
||||
|
||||
def create_access_token(self, username):
|
||||
def create_access_token(
|
||||
self,
|
||||
username,
|
||||
secret_key: str = SECRET_KEY,
|
||||
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
) -> str:
|
||||
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.utcnow() + timedelta(minutes=expires_in)
|
||||
to_encode = {"sub": username, "exp": expire}
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def get_username_from_token(self, token: str) -> Optional[str]:
|
||||
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -64,7 +65,7 @@ class AuthService:
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
@ -1,5 +1,5 @@
|
||||
uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841
|
||||
fastapi==0.61.2
|
||||
fastapi==0.62.0
|
||||
websockets==8.1
|
||||
python-multipart==0.0.5
|
||||
aiohttp==3.7.2
|
||||
|
@ -17,46 +17,237 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from typing import Optional, Union
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from starlette.datastructures import Secret
|
||||
from httpx import AsyncClient
|
||||
from jose import jwt
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.services import auth_service
|
||||
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
|
||||
from gns3server.schemas.users import User
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# async def test_route_exist(app: FastAPI, client: AsyncClient) -> None:
|
||||
#
|
||||
# params = {"username": "test_username", "email": "user@email.com", "password": "test_password"}
|
||||
# response = await client.post(app.url_path_for("create_user"), json=params)
|
||||
# assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||
#
|
||||
#
|
||||
# async def test_users_can_register_successfully(app: FastAPI, client: AsyncClient) -> None:
|
||||
#
|
||||
# user_repo = UsersRepository()
|
||||
# params = {"username": "test_username2", "email": "user2@email.com", "password": "test_password2"}
|
||||
#
|
||||
# # make sure the user doesn't exist in the database
|
||||
# user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||
# assert user_in_db is None
|
||||
#
|
||||
# # register the user
|
||||
# res = await client.post(app.url_path_for("create_user"), json=params)
|
||||
# assert res.status_code == status.HTTP_201_CREATED
|
||||
#
|
||||
# # make sure the user does exists in the database now
|
||||
# user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||
# assert user_in_db is not None
|
||||
# assert user_in_db.email == params["email"]
|
||||
# assert user_in_db.username == params["username"]
|
||||
#
|
||||
# # check that the user returned in the response is equal to the user in the database
|
||||
# created_user = User(**res.json()).json()
|
||||
# print(created_user)
|
||||
# #print(user_in_db.__dict__)
|
||||
# test = jsonable_encoder(user_in_db.__dict__, exclude={"_sa_instance_state", "hashed_password"})
|
||||
# print(test)
|
||||
# assert created_user == test
|
||||
class TestUserRoutes:
|
||||
|
||||
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
|
||||
response = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||
|
||||
async def test_users_can_register_successfully(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
db_session: AsyncSession
|
||||
) -> None:
|
||||
|
||||
user_repo = UsersRepository(db_session)
|
||||
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
|
||||
|
||||
# make sure the user doesn't exist in the database
|
||||
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||
assert user_in_db is None
|
||||
|
||||
# register the user
|
||||
res = await client.post(app.url_path_for("create_user"), json=params)
|
||||
assert res.status_code == status.HTTP_201_CREATED
|
||||
|
||||
# make sure the user does exists in the database now
|
||||
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||
assert user_in_db is not None
|
||||
assert user_in_db.email == params["email"]
|
||||
assert user_in_db.username == params["username"]
|
||||
|
||||
# check that the user returned in the response is equal to the user in the database
|
||||
created_user = User(**res.json()).json()
|
||||
assert created_user == User.from_orm(user_in_db).json()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attr, value, status_code",
|
||||
(
|
||||
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
|
||||
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
|
||||
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
("username", "ab", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
)
|
||||
)
|
||||
async def test_user_registration_fails_when_credentials_are_taken(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
attr: str,
|
||||
value: str,
|
||||
status_code: int,
|
||||
) -> None:
|
||||
|
||||
new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"}
|
||||
new_user[attr] = value
|
||||
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||
assert res.status_code == status_code
|
||||
|
||||
async def test_users_saved_password_is_hashed(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
db_session: AsyncSession
|
||||
) -> None:
|
||||
|
||||
user_repo = UsersRepository(db_session)
|
||||
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
|
||||
|
||||
# send post request to create user and ensure it is successful
|
||||
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||
assert res.status_code == status.HTTP_201_CREATED
|
||||
|
||||
# ensure that the users password is hashed in the db
|
||||
# and that we can verify it using our auth service
|
||||
user_in_db = await user_repo.get_user_by_username(new_user["username"])
|
||||
assert user_in_db is not None
|
||||
assert user_in_db.hashed_password != new_user["password"]
|
||||
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
|
||||
|
||||
|
||||
class TestAuthTokens:
|
||||
|
||||
async def test_can_create_token_successfully(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User
|
||||
) -> None:
|
||||
|
||||
token = auth_service.create_access_token(test_user.username)
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
assert username == test_user.username
|
||||
|
||||
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
token = auth_service.create_access_token(None)
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
async def test_can_retrieve_username_from_token(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User
|
||||
) -> None:
|
||||
|
||||
token = auth_service.create_access_token(test_user.username)
|
||||
username = auth_service.get_username_from_token(token)
|
||||
assert username == test_user.username
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"secret, wrong_token",
|
||||
(
|
||||
(SECRET_KEY, "asdf"), # use wrong token
|
||||
(SECRET_KEY, ""), # use wrong token
|
||||
("ABC123", "use correct token"), # use wrong secret
|
||||
),
|
||||
)
|
||||
async def test_error_when_token_or_secret_is_wrong(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User,
|
||||
secret: Union[Secret, str],
|
||||
wrong_token: Optional[str],
|
||||
) -> None:
|
||||
|
||||
token = auth_service.create_access_token(test_user.username)
|
||||
if wrong_token == "use correct token":
|
||||
wrong_token = token
|
||||
with pytest.raises(HTTPException):
|
||||
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
|
||||
|
||||
|
||||
class TestUserLogin:
|
||||
|
||||
async def test_user_can_login_successfully_and_receives_valid_token(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
|
||||
client.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
login_data = {
|
||||
"username": test_user.username,
|
||||
"password": "user1_password",
|
||||
}
|
||||
res = await client.post(app.url_path_for("login"), data=login_data)
|
||||
assert res.status_code == status.HTTP_200_OK
|
||||
|
||||
# check that token exists in response and has user encoded within it
|
||||
token = res.json().get("access_token")
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
assert "sub" in payload
|
||||
username = payload.get("sub")
|
||||
assert username == test_user.username
|
||||
|
||||
# check that token is proper type
|
||||
assert "token_type" in res.json()
|
||||
assert res.json().get("token_type") == "bearer"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username, password, status_code",
|
||||
(
|
||||
("wrong_username", "user1_password", status.HTTP_401_UNAUTHORIZED),
|
||||
("user1", "wrong_password", status.HTTP_401_UNAUTHORIZED),
|
||||
("user1", None, status.HTTP_401_UNAUTHORIZED),
|
||||
),
|
||||
)
|
||||
async def test_user_with_wrong_creds_doesnt_receive_token(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User,
|
||||
username: str,
|
||||
password: str,
|
||||
status_code: int,
|
||||
) -> None:
|
||||
|
||||
client.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
login_data = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
res = await client.post(app.url_path_for("login"), data=login_data)
|
||||
assert res.status_code == status_code
|
||||
assert "access_token" not in res.json()
|
||||
|
||||
|
||||
class TestUserMe:
|
||||
|
||||
async def test_authenticated_user_can_retrieve_own_data(
|
||||
self,
|
||||
app: FastAPI,
|
||||
authorized_client: AsyncClient,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
|
||||
res = await authorized_client.get(app.url_path_for("get_current_active_user"))
|
||||
assert res.status_code == status.HTTP_200_OK
|
||||
user = User(**res.json())
|
||||
assert user.username == test_user.username
|
||||
assert user.email == test_user.email
|
||||
assert user.user_id == test_user.user_id
|
||||
|
||||
async def test_user_cannot_access_own_data_if_not_authenticated(
|
||||
self, app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User,
|
||||
) -> None:
|
||||
|
||||
res = await client.get(app.url_path_for("get_current_active_user"))
|
||||
assert res.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
@ -7,7 +7,6 @@ import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from asgi_lifespan import LifespanManager
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pathlib import Path
|
||||
@ -17,25 +16,16 @@ from gns3server.config import Config
|
||||
from gns3server.compute import MODULES
|
||||
from gns3server.compute.port_manager import PortManager
|
||||
from gns3server.compute.project_manager import ProjectManager
|
||||
from gns3server.db.database import Base
|
||||
from gns3server.db.models import Base, User
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.api.routes.controller.dependencies.database import get_db_session
|
||||
from gns3server.schemas.users import UserCreate
|
||||
from gns3server.services import auth_service
|
||||
|
||||
sys._called_from_test = True
|
||||
sys.original_platform = sys.platform
|
||||
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
|
||||
|
||||
engine = create_async_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
|
||||
async def start_db():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
||||
@pytest.yield_fixture(scope="session")
|
||||
def event_loop(request):
|
||||
@ -49,34 +39,64 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
||||
yield loop
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# https://github.com/pytest-dev/pytest-asyncio/issues/68
|
||||
# this event_loop is used by pytest-asyncio, and redefining it
|
||||
# is currently the only way of changing the scope of this fixture
|
||||
@pytest.yield_fixture(scope="session")
|
||||
def event_loop(request):
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# async def database_connection() -> None:
|
||||
#
|
||||
# from gns3server.db.tasks import connect_to_db
|
||||
# os.environ["DATABASE_URI"] = "sqlite:///./sql_app_test.db"
|
||||
# await connect_to_db()
|
||||
# yield
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture#(scope="session")
|
||||
async def app() -> FastAPI:
|
||||
@pytest.fixture(scope="session")
|
||||
async def app(db_engine) -> FastAPI:
|
||||
|
||||
from gns3server.api.server import app as gns3_app
|
||||
gns3_app.add_event_handler("startup", start_db())
|
||||
return gns3_app
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
from gns3server.api.server import app as gns3app
|
||||
yield gns3app
|
||||
|
||||
|
||||
# Grab a reference to our database when needed
|
||||
#@pytest.fixture
|
||||
#def db(app: FastAPI) -> Database:
|
||||
# return app.state._db
|
||||
@pytest.fixture(scope="session")
|
||||
def db_engine():
|
||||
|
||||
db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite:///:memory:") # "sqlite:///./sql_test_app.db"
|
||||
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||
yield engine
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
async def db_session(app: FastAPI, db_engine):
|
||||
|
||||
# recreate database tables for each class
|
||||
# preferred and faster way would be to rollback the session/transaction
|
||||
# but it doesn't work for some reason
|
||||
async with db_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
session = AsyncSession(db_engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app: FastAPI) -> AsyncClient:
|
||||
async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
|
||||
|
||||
async def _get_test_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db_session] = _get_test_db
|
||||
|
||||
#async with LifespanManager(app):
|
||||
async with AsyncClient(
|
||||
app=app,
|
||||
base_url="http://test-api",
|
||||
@ -85,6 +105,32 @@ async def client(app: FastAPI) -> AsyncClient:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
|
||||
new_user = UserCreate(
|
||||
username="user1",
|
||||
email="user1@email.com",
|
||||
password="user1_password",
|
||||
)
|
||||
user_repo = UsersRepository(db_session)
|
||||
existing_user = await user_repo.get_user_by_username(new_user.username)
|
||||
if existing_user:
|
||||
return existing_user
|
||||
return await user_repo.create_user(new_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient:
|
||||
|
||||
access_token = auth_service.create_access_token(test_user.username)
|
||||
client.headers = {
|
||||
**client.headers,
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_config_path(tmpdir):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user