1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-01-16 11:00:58 +00:00

User authentication with tests.

This commit is contained in:
grossmj 2020-12-07 16:52:36 +10:30
parent bf7cf862af
commit d47dcb0d6f
14 changed files with 452 additions and 180 deletions

View File

@ -4,6 +4,5 @@ pytest==6.1.2
flake8==3.8.4 flake8==3.8.4
pytest-timeout==1.4.2 pytest-timeout==1.4.2
pytest-asyncio==0.14.0 pytest-asyncio==0.14.0
asgi-lifespan==1.0.1
requests==2.24.0 requests==2.24.0
httpx==0.16.1 httpx==0.16.1

View File

@ -23,12 +23,15 @@ from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
async def get_user_from_token(token: str = Depends(oauth2_scheme), async def get_user_from_token(
user_repo: UsersRepository = Depends()) -> schemas.User: token: str = Depends(oauth2_scheme),
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
username = auth_service.get_username_from_token(token) username = auth_service.get_username_from_token(token)
user = await user_repo.get_user_by_username(username) user = await user_repo.get_user_by_username(username)

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python # -*- coding: utf-8 -*-
# #
# Copyright (C) 2020 GNS3 Technologies Inc. # Copyright (C) 2020 GNS3 Technologies Inc.
# #
@ -15,12 +15,24 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os from typing import Callable, Type
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine from gns3server.db.repositories.base import BaseRepository
from sqlalchemy.orm import declarative_base
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URI", "sqlite:///./sql_app.db")
engine = create_async_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) async def get_db_session(request: Request) -> AsyncSession:
Base = declarative_base()
session = AsyncSession(request.app.state._db_engine)
try:
yield session
finally:
await session.close()
def get_repository(repo: Type[BaseRepository]) -> Callable:
def get_repo(db_session: AsyncSession = Depends(get_db_session)) -> Type[BaseRepository]:
return repo(db_session)
return get_repo

View File

@ -20,7 +20,7 @@ API routes for users.
""" """
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from uuid import UUID from uuid import UUID
from typing import List from typing import List
@ -30,6 +30,7 @@ from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -38,7 +39,7 @@ router = APIRouter()
@router.get("", response_model=List[schemas.User]) @router.get("", response_model=List[schemas.User])
async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User]: async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
""" """
Get all users. Get all users.
""" """
@ -48,7 +49,10 @@ async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED) @router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository = Depends()) -> schemas.User: async def create_user(
new_user: schemas.UserCreate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
""" """
Create a new user. Create a new user.
""" """
@ -63,7 +67,10 @@ async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository =
@router.get("/{user_id}", response_model=schemas.User) @router.get("/{user_id}", response_model=schemas.User)
async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> schemas.User: async def get_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
""" """
Get an user. Get an user.
""" """
@ -75,9 +82,11 @@ async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> sch
@router.put("/{user_id}", response_model=schemas.User) @router.put("/{user_id}", response_model=schemas.User)
async def update_user(user_id: UUID, async def update_user(
update_user: schemas.UserUpdate, user_id: UUID,
user_repo: UsersRepository = Depends()) -> schemas.User: update_user: schemas.UserUpdate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
""" """
Update an user. Update an user.
""" """
@ -89,7 +98,7 @@ async def update_user(user_id: UUID,
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()): async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
""" """
Delete an user. Delete an user.
""" """
@ -100,8 +109,10 @@ async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
@router.post("/login", response_model=schemas.Token) @router.post("/login", response_model=schemas.Token)
async def login(user_repo: UsersRepository = Depends(), async def login(
form_data: OAuth2PasswordRequestForm = Depends()) -> schemas.Token: user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends()
) -> schemas.Token:
""" """
User login. User login.
""" """

View File

@ -45,34 +45,43 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
app = FastAPI(title="GNS3 controller API", def get_application() -> FastAPI:
description="This page describes the public controller API for GNS3",
version="v3")
origins = [ application = FastAPI(
"http://127.0.0.1", title="GNS3 controller API",
"http://localhost", description="This page describes the public controller API for GNS3",
"http://127.0.0.1:8080", version="v3"
"http://localhost:8080", )
"http://127.0.0.1:3080",
"http://localhost:3080",
"http://gns3.github.io",
"https://gns3.github.io"
]
app.add_middleware( origins = [
CORSMiddleware, "http://127.0.0.1",
allow_origins=origins, "http://localhost",
allow_credentials=True, "http://127.0.0.1:8080",
allow_methods=["*"], "http://localhost:8080",
allow_headers=["*"], "http://127.0.0.1:3080",
) "http://localhost:3080",
"http://gns3.github.io",
"https://gns3.github.io"
]
app.add_event_handler("startup", tasks.create_startup_handler(app)) application.add_middleware(
app.add_event_handler("shutdown", tasks.create_shutdown_handler(app)) CORSMiddleware,
app.include_router(index.router, tags=["Index"]) allow_origins=origins,
app.include_router(controller.router, prefix="/v3") allow_credentials=True,
app.mount("/v3/compute", compute_api) allow_methods=["*"],
allow_headers=["*"],
)
application.add_event_handler("startup", tasks.create_startup_handler(application))
application.add_event_handler("shutdown", tasks.create_shutdown_handler(application))
application.include_router(index.router, tags=["Index"])
application.include_router(controller.router, prefix="/v3")
application.mount("/v3/compute", compute_api)
return application
app = get_application()
@app.exception_handler(ControllerError) @app.exception_handler(ControllerError)

View File

@ -55,7 +55,7 @@ def create_startup_handler(app: FastAPI) -> Callable:
loop.set_debug(True) loop.set_debug(True)
# connect to the database # connect to the database
await connect_to_db() await connect_to_db(app)
await Controller.instance().start() await Controller.instance().start()
# Because with a large image collection # Because with a large image collection

View File

@ -21,7 +21,10 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, f
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.types import TypeDecorator, CHAR from sqlalchemy.types import TypeDecorator, CHAR
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from .database import Base
from sqlalchemy.orm import declarative_base
Base = declarative_base()
class GUID(TypeDecorator): class GUID(TypeDecorator):
@ -68,11 +71,14 @@ class BaseTable(Base):
onupdate=func.current_timestamp()) onupdate=func.current_timestamp())
def generate_uuid():
return str(uuid.uuid4())
class User(BaseTable): class User(BaseTable):
__tablename__ = "users" __tablename__ = "users"
user_id = Column(GUID, primary_key=True, default=str(uuid.uuid4())) user_id = Column(GUID, primary_key=True, default=generate_uuid)
username = Column(String, unique=True, index=True) username = Column(String, unique=True, index=True)
email = Column(String, unique=True, index=True) email = Column(String, unique=True, index=True)
full_name = Column(String) full_name = Column(String)

View File

@ -16,14 +16,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from ..database import engine
class BaseRepository: class BaseRepository:
async def db(self): def __init__(self, db_session: AsyncSession) -> None:
session = AsyncSession(engine)
try: self._db_session = db_session
yield session
finally:
await session.close()

View File

@ -20,7 +20,6 @@ from typing import Optional, List
from sqlalchemy import select, update, delete from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from ..database import engine
from .base import BaseRepository from .base import BaseRepository
import gns3server.db.models as models import gns3server.db.models as models
@ -30,73 +29,66 @@ from gns3server.services import auth_service
class UsersRepository(BaseRepository): class UsersRepository(BaseRepository):
def __init__(self) -> None: def __init__(self, db_session: AsyncSession) -> None:
super().__init__() super().__init__(db_session)
self._auth_service = auth_service self._auth_service = auth_service
async def get_user(self, user_id: UUID) -> Optional[models.User]: async def get_user(self, user_id: UUID) -> Optional[models.User]:
async with AsyncSession(engine) as session: result = await self._db_session.execute(select(models.User).where(models.User.user_id == user_id))
result = await session.execute(select(models.User).where(models.User.user_id == user_id)) return result.scalars().first()
return result.scalars().first()
async def get_user_by_username(self, username: str) -> Optional[models.User]: async def get_user_by_username(self, username: str) -> Optional[models.User]:
async with AsyncSession(engine) as session: result = await self._db_session.execute(select(models.User).where(models.User.username == username))
result = await session.execute(select(models.User).where(models.User.username == username)) return result.scalars().first()
return result.scalars().first()
async def get_user_by_email(self, email: str) -> Optional[models.User]: async def get_user_by_email(self, email: str) -> Optional[models.User]:
async with AsyncSession(engine) as session: result = await self._db_session.execute(select(models.User).where(models.User.email == email))
result = await session.execute(select(models.User).where(models.User.email == email)) return result.scalars().first()
return result.scalars().first()
async def get_users(self) -> List[models.User]: async def get_users(self) -> List[models.User]:
async with AsyncSession(engine) as session: result = await self._db_session.execute(select(models.User))
result = await session.execute(select(models.User)) return result.scalars().all()
return result.scalars().all()
async def create_user(self, user: schemas.UserCreate) -> models.User: async def create_user(self, user: schemas.UserCreate) -> models.User:
async with AsyncSession(engine) as session: hashed_password = self._auth_service.hash_password(user.password)
hashed_password = self._auth_service.hash_password(user.password) db_user = models.User(
db_user = models.User(username=user.username, username=user.username,
email=user.email, email=user.email,
full_name=user.full_name, full_name=user.full_name,
hashed_password=hashed_password) hashed_password=hashed_password
session.add(db_user) )
await session.commit() self._db_session.add(db_user)
await session.refresh(db_user) await self._db_session.commit()
return db_user await self._db_session.refresh(db_user)
return db_user
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]: async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
async with AsyncSession(engine) as session: update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None)
if password:
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
update_values = user_update.dict(exclude_unset=True) query = update(models.User) \
password = update_values.pop("password", None) .where(models.User.user_id == user_id) \
if password: .values(update_values)
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
print(update_values) await self._db_session.execute(query)
query = update(models.User) \ await self._db_session.commit()
.where(models.User.user_id == user_id) \ return await self.get_user(user_id)
.values(update_values)
await session.execute(query)
await session.commit()
return await self.get_user(user_id)
async def delete_user(self, user_id: UUID) -> bool: async def delete_user(self, user_id: UUID) -> bool:
async with AsyncSession(engine) as session: query = delete(models.User).where(models.User.user_id == user_id)
query = delete(models.User).where(models.User.user_id == user_id) result = await self._db_session.execute(query)
result = await session.execute(query) await self._db_session.commit()
await session.commit() return result.rowcount > 0
return result.rowcount > 0
#except: #except:
# await session.rollback() # await session.rollback()

View File

@ -15,20 +15,26 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy.exc import SQLAlchemyError import os
from fastapi import FastAPI
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine
from .database import engine
from .models import Base from .models import Base
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
async def connect_to_db() -> None: async def connect_to_db(app: FastAPI) -> None:
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
try: try:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
log.info("Successfully connected to the database") log.info("Successfully connected to the database")
app.state._db_engine = engine
except SQLAlchemyError as e: except SQLAlchemyError as e:
log.error(f"Error while connecting to the database: {e}") log.error(f"Error while connecting to the database: {e}")

View File

@ -35,10 +35,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
#class AuthException(BaseException):
# pass
class AuthService: class AuthService:
def hash_password(self, password: str) -> str: def hash_password(self, password: str) -> str:
@ -49,14 +45,19 @@ class AuthService:
return pwd_context.verify(password, hashed_password) return pwd_context.verify(password, hashed_password)
def create_access_token(self, username): def create_access_token(
self,
username,
secret_key: str = SECRET_KEY,
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
) -> str:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire} to_encode = {"sub": username, "exp": expire}
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
def get_username_from_token(self, token: str) -> Optional[str]: def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -64,7 +65,7 @@ class AuthService:
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise credentials_exception raise credentials_exception

View File

@ -1,5 +1,5 @@
uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841 uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841
fastapi==0.61.2 fastapi==0.62.0
websockets==8.1 websockets==8.1
python-multipart==0.0.5 python-multipart==0.0.5
aiohttp==3.7.2 aiohttp==3.7.2

View File

@ -17,46 +17,237 @@
import pytest import pytest
from fastapi import FastAPI, status from typing import Optional, Union
from fastapi.encoders import jsonable_encoder from fastapi import FastAPI, HTTPException, status
from starlette.datastructures import Secret
from httpx import AsyncClient from httpx import AsyncClient
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
from gns3server.schemas.users import User from gns3server.schemas.users import User
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
# async def test_route_exist(app: FastAPI, client: AsyncClient) -> None: class TestUserRoutes:
#
# params = {"username": "test_username", "email": "user@email.com", "password": "test_password"} async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
# response = await client.post(app.url_path_for("create_user"), json=params)
# assert response.status_code != status.HTTP_404_NOT_FOUND new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
# response = await client.post(app.url_path_for("create_user"), json=new_user)
# assert response.status_code != status.HTTP_404_NOT_FOUND
# async def test_users_can_register_successfully(app: FastAPI, client: AsyncClient) -> None:
# async def test_users_can_register_successfully(
# user_repo = UsersRepository() self,
# params = {"username": "test_username2", "email": "user2@email.com", "password": "test_password2"} app: FastAPI,
# client: AsyncClient,
# # make sure the user doesn't exist in the database db_session: AsyncSession
# user_in_db = await user_repo.get_user_by_username(params["username"]) ) -> None:
# assert user_in_db is None
# user_repo = UsersRepository(db_session)
# # register the user params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
# res = await client.post(app.url_path_for("create_user"), json=params)
# assert res.status_code == status.HTTP_201_CREATED # make sure the user doesn't exist in the database
# user_in_db = await user_repo.get_user_by_username(params["username"])
# # make sure the user does exists in the database now assert user_in_db is None
# user_in_db = await user_repo.get_user_by_username(params["username"])
# assert user_in_db is not None # register the user
# assert user_in_db.email == params["email"] res = await client.post(app.url_path_for("create_user"), json=params)
# assert user_in_db.username == params["username"] assert res.status_code == status.HTTP_201_CREATED
#
# # check that the user returned in the response is equal to the user in the database # make sure the user does exists in the database now
# created_user = User(**res.json()).json() user_in_db = await user_repo.get_user_by_username(params["username"])
# print(created_user) assert user_in_db is not None
# #print(user_in_db.__dict__) assert user_in_db.email == params["email"]
# test = jsonable_encoder(user_in_db.__dict__, exclude={"_sa_instance_state", "hashed_password"}) assert user_in_db.username == params["username"]
# print(test)
# assert created_user == test # check that the user returned in the response is equal to the user in the database
created_user = User(**res.json()).json()
assert created_user == User.from_orm(user_in_db).json()
@pytest.mark.parametrize(
"attr, value, status_code",
(
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "ab", status.HTTP_422_UNPROCESSABLE_ENTITY),
)
)
async def test_user_registration_fails_when_credentials_are_taken(
self,
app: FastAPI,
client: AsyncClient,
attr: str,
value: str,
status_code: int,
) -> None:
new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"}
new_user[attr] = value
res = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status_code
async def test_users_saved_password_is_hashed(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
# send post request to create user and ensure it is successful
res = await client.post(app.url_path_for("create_user"), json=new_user)
assert res.status_code == status.HTTP_201_CREATED
# ensure that the users password is hashed in the db
# and that we can verify it using our auth service
user_in_db = await user_repo.get_user_by_username(new_user["username"])
assert user_in_db is not None
assert user_in_db.hashed_password != new_user["password"]
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
class TestAuthTokens:
async def test_can_create_token_successfully(
self,
app: FastAPI,
client: AsyncClient,
test_user: User
) -> None:
token = auth_service.create_access_token(test_user.username)
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username = payload.get("sub")
assert username == test_user.username
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
token = auth_service.create_access_token(None)
with pytest.raises(jwt.JWTError):
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
async def test_can_retrieve_username_from_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User
) -> None:
token = auth_service.create_access_token(test_user.username)
username = auth_service.get_username_from_token(token)
assert username == test_user.username
@pytest.mark.parametrize(
"secret, wrong_token",
(
(SECRET_KEY, "asdf"), # use wrong token
(SECRET_KEY, ""), # use wrong token
("ABC123", "use correct token"), # use wrong secret
),
)
async def test_error_when_token_or_secret_is_wrong(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
secret: Union[Secret, str],
wrong_token: Optional[str],
) -> None:
token = auth_service.create_access_token(test_user.username)
if wrong_token == "use correct token":
wrong_token = token
with pytest.raises(HTTPException):
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
class TestUserLogin:
async def test_user_can_login_successfully_and_receives_valid_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
) -> None:
client.headers["content-type"] = "application/x-www-form-urlencoded"
login_data = {
"username": test_user.username,
"password": "user1_password",
}
res = await client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status.HTTP_200_OK
# check that token exists in response and has user encoded within it
token = res.json().get("access_token")
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert "sub" in payload
username = payload.get("sub")
assert username == test_user.username
# check that token is proper type
assert "token_type" in res.json()
assert res.json().get("token_type") == "bearer"
@pytest.mark.parametrize(
"username, password, status_code",
(
("wrong_username", "user1_password", status.HTTP_401_UNAUTHORIZED),
("user1", "wrong_password", status.HTTP_401_UNAUTHORIZED),
("user1", None, status.HTTP_401_UNAUTHORIZED),
),
)
async def test_user_with_wrong_creds_doesnt_receive_token(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
username: str,
password: str,
status_code: int,
) -> None:
client.headers["content-type"] = "application/x-www-form-urlencoded"
login_data = {
"username": username,
"password": password,
}
res = await client.post(app.url_path_for("login"), data=login_data)
assert res.status_code == status_code
assert "access_token" not in res.json()
class TestUserMe:
async def test_authenticated_user_can_retrieve_own_data(
self,
app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
) -> None:
res = await authorized_client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_200_OK
user = User(**res.json())
assert user.username == test_user.username
assert user.email == test_user.email
assert user.user_id == test_user.user_id
async def test_user_cannot_access_own_data_if_not_authenticated(
self, app: FastAPI,
client: AsyncClient,
test_user: User,
) -> None:
res = await client.get(app.url_path_for("get_current_active_user"))
assert res.status_code == status.HTTP_401_UNAUTHORIZED

View File

@ -7,7 +7,6 @@ import os
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from asgi_lifespan import LifespanManager
from httpx import AsyncClient from httpx import AsyncClient
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from pathlib import Path from pathlib import Path
@ -17,25 +16,16 @@ from gns3server.config import Config
from gns3server.compute import MODULES from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager from gns3server.compute.port_manager import PortManager
from gns3server.compute.project_manager import ProjectManager from gns3server.compute.project_manager import ProjectManager
from gns3server.db.database import Base from gns3server.db.models import Base, User
from gns3server.db.repositories.users import UsersRepository
from gns3server.api.routes.controller.dependencies.database import get_db_session
from gns3server.schemas.users import UserCreate
from gns3server.services import auth_service
sys._called_from_test = True sys._called_from_test = True
sys.original_platform = sys.platform sys.original_platform = sys.platform
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_async_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
async def start_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
if sys.platform.startswith("win") and sys.version_info < (3, 8): if sys.platform.startswith("win") and sys.version_info < (3, 8):
@pytest.yield_fixture(scope="session") @pytest.yield_fixture(scope="session")
def event_loop(request): def event_loop(request):
@ -49,34 +39,64 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
yield loop yield loop
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
# https://github.com/pytest-dev/pytest-asyncio/issues/68
# this event_loop is used by pytest-asyncio, and redefining it
# is currently the only way of changing the scope of this fixture
@pytest.yield_fixture(scope="session")
def event_loop(request):
# @pytest.mark.asyncio loop = asyncio.get_event_loop_policy().new_event_loop()
# @pytest.fixture(scope="session", autouse=True) yield loop
# async def database_connection() -> None: loop.close()
#
# from gns3server.db.tasks import connect_to_db
# os.environ["DATABASE_URI"] = "sqlite:///./sql_app_test.db"
# await connect_to_db()
# yield
@pytest.fixture#(scope="session") @pytest.fixture(scope="session")
async def app() -> FastAPI: async def app(db_engine) -> FastAPI:
from gns3server.api.server import app as gns3_app async with db_engine.begin() as conn:
gns3_app.add_event_handler("startup", start_db()) await conn.run_sync(Base.metadata.drop_all)
return gns3_app await conn.run_sync(Base.metadata.create_all)
from gns3server.api.server import app as gns3app
yield gns3app
# Grab a reference to our database when needed @pytest.fixture(scope="session")
#@pytest.fixture def db_engine():
#def db(app: FastAPI) -> Database:
# return app.state._db db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite:///:memory:") # "sqlite:///./sql_test_app.db"
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
yield engine
engine.sync_engine.dispose()
@pytest.fixture(scope="class")
async def db_session(app: FastAPI, db_engine):
# recreate database tables for each class
# preferred and faster way would be to rollback the session/transaction
# but it doesn't work for some reason
async with db_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
session = AsyncSession(db_engine)
try:
yield session
finally:
await session.close()
@pytest.fixture @pytest.fixture
async def client(app: FastAPI) -> AsyncClient: async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
async def _get_test_db():
try:
yield db_session
finally:
pass
app.dependency_overrides[get_db_session] = _get_test_db
#async with LifespanManager(app):
async with AsyncClient( async with AsyncClient(
app=app, app=app,
base_url="http://test-api", base_url="http://test-api",
@ -85,6 +105,32 @@ async def client(app: FastAPI) -> AsyncClient:
yield client yield client
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
new_user = UserCreate(
username="user1",
email="user1@email.com",
password="user1_password",
)
user_repo = UsersRepository(db_session)
existing_user = await user_repo.get_user_by_username(new_user.username)
if existing_user:
return existing_user
return await user_repo.create_user(new_user)
@pytest.fixture
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient:
access_token = auth_service.create_access_token(test_user.username)
client.headers = {
**client.headers,
"Authorization": f"Bearer {access_token}",
}
return client
@pytest.fixture @pytest.fixture
def controller_config_path(tmpdir): def controller_config_path(tmpdir):