mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-16 11:00:58 +00:00
User authentication with tests.
This commit is contained in:
parent
bf7cf862af
commit
d47dcb0d6f
@ -4,6 +4,5 @@ pytest==6.1.2
|
|||||||
flake8==3.8.4
|
flake8==3.8.4
|
||||||
pytest-timeout==1.4.2
|
pytest-timeout==1.4.2
|
||||||
pytest-asyncio==0.14.0
|
pytest-asyncio==0.14.0
|
||||||
asgi-lifespan==1.0.1
|
|
||||||
requests==2.24.0
|
requests==2.24.0
|
||||||
httpx==0.16.1
|
httpx==0.16.1
|
||||||
|
@ -23,12 +23,15 @@ from gns3server import schemas
|
|||||||
from gns3server.db.repositories.users import UsersRepository
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
from gns3server.services import auth_service
|
from gns3server.services import auth_service
|
||||||
|
|
||||||
|
from .database import get_repository
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
|
||||||
|
|
||||||
|
|
||||||
async def get_user_from_token(token: str = Depends(oauth2_scheme),
|
async def get_user_from_token(
|
||||||
user_repo: UsersRepository = Depends()) -> schemas.User:
|
token: str = Depends(oauth2_scheme),
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
|
) -> schemas.User:
|
||||||
|
|
||||||
username = auth_service.get_username_from_token(token)
|
username = auth_service.get_username_from_token(token)
|
||||||
user = await user_repo.get_user_by_username(username)
|
user = await user_repo.get_user_by_username(username)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
#!/usr/bin/env python
|
# -*- coding: utf-8 -*-
|
||||||
#
|
#
|
||||||
# Copyright (C) 2020 GNS3 Technologies Inc.
|
# Copyright (C) 2020 GNS3 Technologies Inc.
|
||||||
#
|
#
|
||||||
@ -15,12 +15,24 @@
|
|||||||
# You should have received a copy of the GNU General Public License
|
# You should have received a copy of the GNU General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import os
|
from typing import Callable, Type
|
||||||
|
from fastapi import Depends, Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from gns3server.db.repositories.base import BaseRepository
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URI", "sqlite:///./sql_app.db")
|
|
||||||
|
|
||||||
engine = create_async_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
async def get_db_session(request: Request) -> AsyncSession:
|
||||||
Base = declarative_base()
|
|
||||||
|
session = AsyncSession(request.app.state._db_engine)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_repository(repo: Type[BaseRepository]) -> Callable:
|
||||||
|
|
||||||
|
def get_repo(db_session: AsyncSession = Depends(get_db_session)) -> Type[BaseRepository]:
|
||||||
|
return repo(db_session)
|
||||||
|
return get_repo
|
@ -20,7 +20,7 @@ API routes for users.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -30,6 +30,7 @@ from gns3server.db.repositories.users import UsersRepository
|
|||||||
from gns3server.services import auth_service
|
from gns3server.services import auth_service
|
||||||
|
|
||||||
from .dependencies.authentication import get_current_active_user
|
from .dependencies.authentication import get_current_active_user
|
||||||
|
from .dependencies.database import get_repository
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -38,7 +39,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[schemas.User])
|
@router.get("", response_model=List[schemas.User])
|
||||||
async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User]:
|
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
|
||||||
"""
|
"""
|
||||||
Get all users.
|
Get all users.
|
||||||
"""
|
"""
|
||||||
@ -48,7 +49,10 @@ async def get_users(user_repo: UsersRepository = Depends()) -> List[schemas.User
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository = Depends()) -> schemas.User:
|
async def create_user(
|
||||||
|
new_user: schemas.UserCreate,
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Create a new user.
|
Create a new user.
|
||||||
"""
|
"""
|
||||||
@ -63,7 +67,10 @@ async def create_user(new_user: schemas.UserCreate, user_repo: UsersRepository =
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=schemas.User)
|
@router.get("/{user_id}", response_model=schemas.User)
|
||||||
async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> schemas.User:
|
async def get_user(
|
||||||
|
user_id: UUID,
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Get an user.
|
Get an user.
|
||||||
"""
|
"""
|
||||||
@ -75,9 +82,11 @@ async def get_user(user_id: UUID, user_repo: UsersRepository = Depends()) -> sch
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=schemas.User)
|
@router.put("/{user_id}", response_model=schemas.User)
|
||||||
async def update_user(user_id: UUID,
|
async def update_user(
|
||||||
update_user: schemas.UserUpdate,
|
user_id: UUID,
|
||||||
user_repo: UsersRepository = Depends()) -> schemas.User:
|
update_user: schemas.UserUpdate,
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Update an user.
|
Update an user.
|
||||||
"""
|
"""
|
||||||
@ -89,7 +98,7 @@ async def update_user(user_id: UUID,
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
|
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
|
||||||
"""
|
"""
|
||||||
Delete an user.
|
Delete an user.
|
||||||
"""
|
"""
|
||||||
@ -100,8 +109,10 @@ async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends()):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=schemas.Token)
|
@router.post("/login", response_model=schemas.Token)
|
||||||
async def login(user_repo: UsersRepository = Depends(),
|
async def login(
|
||||||
form_data: OAuth2PasswordRequestForm = Depends()) -> schemas.Token:
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends()
|
||||||
|
) -> schemas.Token:
|
||||||
"""
|
"""
|
||||||
User login.
|
User login.
|
||||||
"""
|
"""
|
||||||
|
@ -45,34 +45,43 @@ import logging
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="GNS3 controller API",
|
def get_application() -> FastAPI:
|
||||||
description="This page describes the public controller API for GNS3",
|
|
||||||
version="v3")
|
|
||||||
|
|
||||||
origins = [
|
application = FastAPI(
|
||||||
"http://127.0.0.1",
|
title="GNS3 controller API",
|
||||||
"http://localhost",
|
description="This page describes the public controller API for GNS3",
|
||||||
"http://127.0.0.1:8080",
|
version="v3"
|
||||||
"http://localhost:8080",
|
)
|
||||||
"http://127.0.0.1:3080",
|
|
||||||
"http://localhost:3080",
|
|
||||||
"http://gns3.github.io",
|
|
||||||
"https://gns3.github.io"
|
|
||||||
]
|
|
||||||
|
|
||||||
app.add_middleware(
|
origins = [
|
||||||
CORSMiddleware,
|
"http://127.0.0.1",
|
||||||
allow_origins=origins,
|
"http://localhost",
|
||||||
allow_credentials=True,
|
"http://127.0.0.1:8080",
|
||||||
allow_methods=["*"],
|
"http://localhost:8080",
|
||||||
allow_headers=["*"],
|
"http://127.0.0.1:3080",
|
||||||
)
|
"http://localhost:3080",
|
||||||
|
"http://gns3.github.io",
|
||||||
|
"https://gns3.github.io"
|
||||||
|
]
|
||||||
|
|
||||||
app.add_event_handler("startup", tasks.create_startup_handler(app))
|
application.add_middleware(
|
||||||
app.add_event_handler("shutdown", tasks.create_shutdown_handler(app))
|
CORSMiddleware,
|
||||||
app.include_router(index.router, tags=["Index"])
|
allow_origins=origins,
|
||||||
app.include_router(controller.router, prefix="/v3")
|
allow_credentials=True,
|
||||||
app.mount("/v3/compute", compute_api)
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
application.add_event_handler("startup", tasks.create_startup_handler(application))
|
||||||
|
application.add_event_handler("shutdown", tasks.create_shutdown_handler(application))
|
||||||
|
application.include_router(index.router, tags=["Index"])
|
||||||
|
application.include_router(controller.router, prefix="/v3")
|
||||||
|
application.mount("/v3/compute", compute_api)
|
||||||
|
|
||||||
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
app = get_application()
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(ControllerError)
|
@app.exception_handler(ControllerError)
|
||||||
|
@ -55,7 +55,7 @@ def create_startup_handler(app: FastAPI) -> Callable:
|
|||||||
loop.set_debug(True)
|
loop.set_debug(True)
|
||||||
|
|
||||||
# connect to the database
|
# connect to the database
|
||||||
await connect_to_db()
|
await connect_to_db(app)
|
||||||
|
|
||||||
await Controller.instance().start()
|
await Controller.instance().start()
|
||||||
# Because with a large image collection
|
# Because with a large image collection
|
||||||
|
@ -21,7 +21,10 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, f
|
|||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.types import TypeDecorator, CHAR
|
from sqlalchemy.types import TypeDecorator, CHAR
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from .database import Base
|
|
||||||
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class GUID(TypeDecorator):
|
class GUID(TypeDecorator):
|
||||||
@ -68,11 +71,14 @@ class BaseTable(Base):
|
|||||||
onupdate=func.current_timestamp())
|
onupdate=func.current_timestamp())
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uuid():
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
class User(BaseTable):
|
class User(BaseTable):
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
user_id = Column(GUID, primary_key=True, default=str(uuid.uuid4()))
|
user_id = Column(GUID, primary_key=True, default=generate_uuid)
|
||||||
username = Column(String, unique=True, index=True)
|
username = Column(String, unique=True, index=True)
|
||||||
email = Column(String, unique=True, index=True)
|
email = Column(String, unique=True, index=True)
|
||||||
full_name = Column(String)
|
full_name = Column(String)
|
||||||
|
@ -16,14 +16,10 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from ..database import engine
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRepository:
|
class BaseRepository:
|
||||||
|
|
||||||
async def db(self):
|
def __init__(self, db_session: AsyncSession) -> None:
|
||||||
session = AsyncSession(engine)
|
|
||||||
try:
|
self._db_session = db_session
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
@ -20,7 +20,6 @@ from typing import Optional, List
|
|||||||
from sqlalchemy import select, update, delete
|
from sqlalchemy import select, update, delete
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from ..database import engine
|
|
||||||
from .base import BaseRepository
|
from .base import BaseRepository
|
||||||
|
|
||||||
import gns3server.db.models as models
|
import gns3server.db.models as models
|
||||||
@ -30,73 +29,66 @@ from gns3server.services import auth_service
|
|||||||
|
|
||||||
class UsersRepository(BaseRepository):
|
class UsersRepository(BaseRepository):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, db_session: AsyncSession) -> None:
|
||||||
|
|
||||||
super().__init__()
|
super().__init__(db_session)
|
||||||
self._auth_service = auth_service
|
self._auth_service = auth_service
|
||||||
|
|
||||||
async def get_user(self, user_id: UUID) -> Optional[models.User]:
|
async def get_user(self, user_id: UUID) -> Optional[models.User]:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
result = await self._db_session.execute(select(models.User).where(models.User.user_id == user_id))
|
||||||
result = await session.execute(select(models.User).where(models.User.user_id == user_id))
|
return result.scalars().first()
|
||||||
return result.scalars().first()
|
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str) -> Optional[models.User]:
|
async def get_user_by_username(self, username: str) -> Optional[models.User]:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
result = await self._db_session.execute(select(models.User).where(models.User.username == username))
|
||||||
result = await session.execute(select(models.User).where(models.User.username == username))
|
return result.scalars().first()
|
||||||
return result.scalars().first()
|
|
||||||
|
|
||||||
async def get_user_by_email(self, email: str) -> Optional[models.User]:
|
async def get_user_by_email(self, email: str) -> Optional[models.User]:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
result = await self._db_session.execute(select(models.User).where(models.User.email == email))
|
||||||
result = await session.execute(select(models.User).where(models.User.email == email))
|
return result.scalars().first()
|
||||||
return result.scalars().first()
|
|
||||||
|
|
||||||
async def get_users(self) -> List[models.User]:
|
async def get_users(self) -> List[models.User]:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
result = await self._db_session.execute(select(models.User))
|
||||||
result = await session.execute(select(models.User))
|
return result.scalars().all()
|
||||||
return result.scalars().all()
|
|
||||||
|
|
||||||
async def create_user(self, user: schemas.UserCreate) -> models.User:
|
async def create_user(self, user: schemas.UserCreate) -> models.User:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
hashed_password = self._auth_service.hash_password(user.password)
|
||||||
hashed_password = self._auth_service.hash_password(user.password)
|
db_user = models.User(
|
||||||
db_user = models.User(username=user.username,
|
username=user.username,
|
||||||
email=user.email,
|
email=user.email,
|
||||||
full_name=user.full_name,
|
full_name=user.full_name,
|
||||||
hashed_password=hashed_password)
|
hashed_password=hashed_password
|
||||||
session.add(db_user)
|
)
|
||||||
await session.commit()
|
self._db_session.add(db_user)
|
||||||
await session.refresh(db_user)
|
await self._db_session.commit()
|
||||||
return db_user
|
await self._db_session.refresh(db_user)
|
||||||
|
return db_user
|
||||||
|
|
||||||
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
|
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
update_values = user_update.dict(exclude_unset=True)
|
||||||
|
password = update_values.pop("password", None)
|
||||||
|
if password:
|
||||||
|
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
|
||||||
|
|
||||||
update_values = user_update.dict(exclude_unset=True)
|
query = update(models.User) \
|
||||||
password = update_values.pop("password", None)
|
.where(models.User.user_id == user_id) \
|
||||||
if password:
|
.values(update_values)
|
||||||
update_values["hashed_password"] = self._auth_service.hash_password(password=password)
|
|
||||||
|
|
||||||
print(update_values)
|
await self._db_session.execute(query)
|
||||||
query = update(models.User) \
|
await self._db_session.commit()
|
||||||
.where(models.User.user_id == user_id) \
|
return await self.get_user(user_id)
|
||||||
.values(update_values)
|
|
||||||
|
|
||||||
await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
return await self.get_user(user_id)
|
|
||||||
|
|
||||||
async def delete_user(self, user_id: UUID) -> bool:
|
async def delete_user(self, user_id: UUID) -> bool:
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
query = delete(models.User).where(models.User.user_id == user_id)
|
||||||
query = delete(models.User).where(models.User.user_id == user_id)
|
result = await self._db_session.execute(query)
|
||||||
result = await session.execute(query)
|
await self._db_session.commit()
|
||||||
await session.commit()
|
return result.rowcount > 0
|
||||||
return result.rowcount > 0
|
|
||||||
#except:
|
#except:
|
||||||
# await session.rollback()
|
# await session.rollback()
|
||||||
|
|
||||||
|
@ -15,20 +15,26 @@
|
|||||||
# You should have received a copy of the GNU General Public License
|
# You should have received a copy of the GNU General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
from .database import engine
|
|
||||||
from .models import Base
|
from .models import Base
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def connect_to_db() -> None:
|
async def connect_to_db(app: FastAPI) -> None:
|
||||||
|
|
||||||
|
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
|
||||||
|
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||||
try:
|
try:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
log.info("Successfully connected to the database")
|
log.info("Successfully connected to the database")
|
||||||
|
app.state._db_engine = engine
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
log.error(f"Error while connecting to the database: {e}")
|
log.error(f"Error while connecting to the database: {e}")
|
||||||
|
@ -35,10 +35,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
#class AuthException(BaseException):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
|
|
||||||
def hash_password(self, password: str) -> str:
|
def hash_password(self, password: str) -> str:
|
||||||
@ -49,14 +45,19 @@ class AuthService:
|
|||||||
|
|
||||||
return pwd_context.verify(password, hashed_password)
|
return pwd_context.verify(password, hashed_password)
|
||||||
|
|
||||||
def create_access_token(self, username):
|
def create_access_token(
|
||||||
|
self,
|
||||||
|
username,
|
||||||
|
secret_key: str = SECRET_KEY,
|
||||||
|
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
|
) -> str:
|
||||||
|
|
||||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.utcnow() + timedelta(minutes=expires_in)
|
||||||
to_encode = {"sub": username, "exp": expire}
|
to_encode = {"sub": username, "exp": expire}
|
||||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
def get_username_from_token(self, token: str) -> Optional[str]:
|
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
|
||||||
|
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -64,7 +65,7 @@ class AuthService:
|
|||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841
|
uvicorn==0.11.8 # force version to 0.11.8 because of https://github.com/encode/uvicorn/issues/841
|
||||||
fastapi==0.61.2
|
fastapi==0.62.0
|
||||||
websockets==8.1
|
websockets==8.1
|
||||||
python-multipart==0.0.5
|
python-multipart==0.0.5
|
||||||
aiohttp==3.7.2
|
aiohttp==3.7.2
|
||||||
|
@ -17,46 +17,237 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from fastapi import FastAPI, status
|
from typing import Optional, Union
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi import FastAPI, HTTPException, status
|
||||||
|
from starlette.datastructures import Secret
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from gns3server.db.repositories.users import UsersRepository
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
|
from gns3server.services import auth_service
|
||||||
|
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
|
||||||
from gns3server.schemas.users import User
|
from gns3server.schemas.users import User
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
# async def test_route_exist(app: FastAPI, client: AsyncClient) -> None:
|
class TestUserRoutes:
|
||||||
#
|
|
||||||
# params = {"username": "test_username", "email": "user@email.com", "password": "test_password"}
|
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
|
||||||
# response = await client.post(app.url_path_for("create_user"), json=params)
|
|
||||||
# assert response.status_code != status.HTTP_404_NOT_FOUND
|
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
|
||||||
#
|
response = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||||
#
|
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||||
# async def test_users_can_register_successfully(app: FastAPI, client: AsyncClient) -> None:
|
|
||||||
#
|
async def test_users_can_register_successfully(
|
||||||
# user_repo = UsersRepository()
|
self,
|
||||||
# params = {"username": "test_username2", "email": "user2@email.com", "password": "test_password2"}
|
app: FastAPI,
|
||||||
#
|
client: AsyncClient,
|
||||||
# # make sure the user doesn't exist in the database
|
db_session: AsyncSession
|
||||||
# user_in_db = await user_repo.get_user_by_username(params["username"])
|
) -> None:
|
||||||
# assert user_in_db is None
|
|
||||||
#
|
user_repo = UsersRepository(db_session)
|
||||||
# # register the user
|
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
|
||||||
# res = await client.post(app.url_path_for("create_user"), json=params)
|
|
||||||
# assert res.status_code == status.HTTP_201_CREATED
|
# make sure the user doesn't exist in the database
|
||||||
#
|
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||||
# # make sure the user does exists in the database now
|
assert user_in_db is None
|
||||||
# user_in_db = await user_repo.get_user_by_username(params["username"])
|
|
||||||
# assert user_in_db is not None
|
# register the user
|
||||||
# assert user_in_db.email == params["email"]
|
res = await client.post(app.url_path_for("create_user"), json=params)
|
||||||
# assert user_in_db.username == params["username"]
|
assert res.status_code == status.HTTP_201_CREATED
|
||||||
#
|
|
||||||
# # check that the user returned in the response is equal to the user in the database
|
# make sure the user does exists in the database now
|
||||||
# created_user = User(**res.json()).json()
|
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||||
# print(created_user)
|
assert user_in_db is not None
|
||||||
# #print(user_in_db.__dict__)
|
assert user_in_db.email == params["email"]
|
||||||
# test = jsonable_encoder(user_in_db.__dict__, exclude={"_sa_instance_state", "hashed_password"})
|
assert user_in_db.username == params["username"]
|
||||||
# print(test)
|
|
||||||
# assert created_user == test
|
# check that the user returned in the response is equal to the user in the database
|
||||||
|
created_user = User(**res.json()).json()
|
||||||
|
assert created_user == User.from_orm(user_in_db).json()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"attr, value, status_code",
|
||||||
|
(
|
||||||
|
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
|
||||||
|
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
|
||||||
|
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
|
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
|
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
|
("username", "ab", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async def test_user_registration_fails_when_credentials_are_taken(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
attr: str,
|
||||||
|
value: str,
|
||||||
|
status_code: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"}
|
||||||
|
new_user[attr] = value
|
||||||
|
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||||
|
assert res.status_code == status_code
|
||||||
|
|
||||||
|
async def test_users_saved_password_is_hashed(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
user_repo = UsersRepository(db_session)
|
||||||
|
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
|
||||||
|
|
||||||
|
# send post request to create user and ensure it is successful
|
||||||
|
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||||
|
assert res.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
# ensure that the users password is hashed in the db
|
||||||
|
# and that we can verify it using our auth service
|
||||||
|
user_in_db = await user_repo.get_user_by_username(new_user["username"])
|
||||||
|
assert user_in_db is not None
|
||||||
|
assert user_in_db.hashed_password != new_user["password"]
|
||||||
|
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthTokens:
|
||||||
|
|
||||||
|
async def test_can_create_token_successfully(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
token = auth_service.create_access_token(test_user.username)
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username = payload.get("sub")
|
||||||
|
assert username == test_user.username
|
||||||
|
|
||||||
|
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
|
||||||
|
|
||||||
|
token = auth_service.create_access_token(None)
|
||||||
|
with pytest.raises(jwt.JWTError):
|
||||||
|
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
|
||||||
|
async def test_can_retrieve_username_from_token(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
token = auth_service.create_access_token(test_user.username)
|
||||||
|
username = auth_service.get_username_from_token(token)
|
||||||
|
assert username == test_user.username
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"secret, wrong_token",
|
||||||
|
(
|
||||||
|
(SECRET_KEY, "asdf"), # use wrong token
|
||||||
|
(SECRET_KEY, ""), # use wrong token
|
||||||
|
("ABC123", "use correct token"), # use wrong secret
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_error_when_token_or_secret_is_wrong(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
secret: Union[Secret, str],
|
||||||
|
wrong_token: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
token = auth_service.create_access_token(test_user.username)
|
||||||
|
if wrong_token == "use correct token":
|
||||||
|
wrong_token = token
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserLogin:
|
||||||
|
|
||||||
|
async def test_user_can_login_successfully_and_receives_valid_token(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
client.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||||
|
login_data = {
|
||||||
|
"username": test_user.username,
|
||||||
|
"password": "user1_password",
|
||||||
|
}
|
||||||
|
res = await client.post(app.url_path_for("login"), data=login_data)
|
||||||
|
assert res.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# check that token exists in response and has user encoded within it
|
||||||
|
token = res.json().get("access_token")
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
assert "sub" in payload
|
||||||
|
username = payload.get("sub")
|
||||||
|
assert username == test_user.username
|
||||||
|
|
||||||
|
# check that token is proper type
|
||||||
|
assert "token_type" in res.json()
|
||||||
|
assert res.json().get("token_type") == "bearer"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"username, password, status_code",
|
||||||
|
(
|
||||||
|
("wrong_username", "user1_password", status.HTTP_401_UNAUTHORIZED),
|
||||||
|
("user1", "wrong_password", status.HTTP_401_UNAUTHORIZED),
|
||||||
|
("user1", None, status.HTTP_401_UNAUTHORIZED),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def test_user_with_wrong_creds_doesnt_receive_token(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
status_code: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
client.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||||
|
login_data = {
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
}
|
||||||
|
res = await client.post(app.url_path_for("login"), data=login_data)
|
||||||
|
assert res.status_code == status_code
|
||||||
|
assert "access_token" not in res.json()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserMe:
|
||||||
|
|
||||||
|
async def test_authenticated_user_can_retrieve_own_data(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
authorized_client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
res = await authorized_client.get(app.url_path_for("get_current_active_user"))
|
||||||
|
assert res.status_code == status.HTTP_200_OK
|
||||||
|
user = User(**res.json())
|
||||||
|
assert user.username == test_user.username
|
||||||
|
assert user.email == test_user.email
|
||||||
|
assert user.user_id == test_user.user_id
|
||||||
|
|
||||||
|
async def test_user_cannot_access_own_data_if_not_authenticated(
|
||||||
|
self, app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
res = await client.get(app.url_path_for("get_current_active_user"))
|
||||||
|
assert res.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
@ -7,7 +7,6 @@ import os
|
|||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
from asgi_lifespan import LifespanManager
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -17,25 +16,16 @@ from gns3server.config import Config
|
|||||||
from gns3server.compute import MODULES
|
from gns3server.compute import MODULES
|
||||||
from gns3server.compute.port_manager import PortManager
|
from gns3server.compute.port_manager import PortManager
|
||||||
from gns3server.compute.project_manager import ProjectManager
|
from gns3server.compute.project_manager import ProjectManager
|
||||||
from gns3server.db.database import Base
|
from gns3server.db.models import Base, User
|
||||||
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
|
from gns3server.api.routes.controller.dependencies.database import get_db_session
|
||||||
|
from gns3server.schemas.users import UserCreate
|
||||||
|
from gns3server.services import auth_service
|
||||||
|
|
||||||
sys._called_from_test = True
|
sys._called_from_test = True
|
||||||
sys.original_platform = sys.platform
|
sys.original_platform = sys.platform
|
||||||
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
|
|
||||||
|
|
||||||
engine = create_async_engine(
|
|
||||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_db():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
|
|
||||||
if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
||||||
@pytest.yield_fixture(scope="session")
|
@pytest.yield_fixture(scope="session")
|
||||||
def event_loop(request):
|
def event_loop(request):
|
||||||
@ -49,34 +39,64 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
|||||||
yield loop
|
yield loop
|
||||||
asyncio.set_event_loop(None)
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# https://github.com/pytest-dev/pytest-asyncio/issues/68
|
||||||
|
# this event_loop is used by pytest-asyncio, and redefining it
|
||||||
|
# is currently the only way of changing the scope of this fixture
|
||||||
|
@pytest.yield_fixture(scope="session")
|
||||||
|
def event_loop(request):
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
# @pytest.fixture(scope="session", autouse=True)
|
yield loop
|
||||||
# async def database_connection() -> None:
|
loop.close()
|
||||||
#
|
|
||||||
# from gns3server.db.tasks import connect_to_db
|
|
||||||
# os.environ["DATABASE_URI"] = "sqlite:///./sql_app_test.db"
|
|
||||||
# await connect_to_db()
|
|
||||||
# yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture#(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
async def app() -> FastAPI:
|
async def app(db_engine) -> FastAPI:
|
||||||
|
|
||||||
from gns3server.api.server import app as gns3_app
|
async with db_engine.begin() as conn:
|
||||||
gns3_app.add_event_handler("startup", start_db())
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
return gns3_app
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
from gns3server.api.server import app as gns3app
|
||||||
|
yield gns3app
|
||||||
|
|
||||||
|
|
||||||
# Grab a reference to our database when needed
|
@pytest.fixture(scope="session")
|
||||||
#@pytest.fixture
|
def db_engine():
|
||||||
#def db(app: FastAPI) -> Database:
|
|
||||||
# return app.state._db
|
db_url = os.getenv("GNS3_TEST_DATABASE_URI", "sqlite:///:memory:") # "sqlite:///./sql_test_app.db"
|
||||||
|
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||||
|
yield engine
|
||||||
|
engine.sync_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
async def db_session(app: FastAPI, db_engine):
|
||||||
|
|
||||||
|
# recreate database tables for each class
|
||||||
|
# preferred and faster way would be to rollback the session/transaction
|
||||||
|
# but it doesn't work for some reason
|
||||||
|
async with db_engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
session = AsyncSession(db_engine)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client(app: FastAPI) -> AsyncClient:
|
async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
|
||||||
|
|
||||||
|
async def _get_test_db():
|
||||||
|
try:
|
||||||
|
yield db_session
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db_session] = _get_test_db
|
||||||
|
|
||||||
#async with LifespanManager(app):
|
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
app=app,
|
app=app,
|
||||||
base_url="http://test-api",
|
base_url="http://test-api",
|
||||||
@ -85,6 +105,32 @@ async def client(app: FastAPI) -> AsyncClient:
|
|||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_user(db_session: AsyncSession) -> User:
|
||||||
|
|
||||||
|
new_user = UserCreate(
|
||||||
|
username="user1",
|
||||||
|
email="user1@email.com",
|
||||||
|
password="user1_password",
|
||||||
|
)
|
||||||
|
user_repo = UsersRepository(db_session)
|
||||||
|
existing_user = await user_repo.get_user_by_username(new_user.username)
|
||||||
|
if existing_user:
|
||||||
|
return existing_user
|
||||||
|
return await user_repo.create_user(new_user)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient:
|
||||||
|
|
||||||
|
access_token = auth_service.create_access_token(test_user.username)
|
||||||
|
client.headers = {
|
||||||
|
**client.headers,
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def controller_config_path(tmpdir):
|
def controller_config_path(tmpdir):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user