mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-13 17:40:54 +00:00
Generate JWT secret key if none is configured in the config file.
Change location of the database.
This commit is contained in:
parent
509e762cda
commit
bde706d19a
@ -4,10 +4,15 @@ host = 0.0.0.0
|
|||||||
; HTTP port for controlling the servers
|
; HTTP port for controlling the servers
|
||||||
port = 3080
|
port = 3080
|
||||||
|
|
||||||
; Option to enable SSL encryption
|
; Options to enable SSL encryption
|
||||||
ssl = False
|
ssl = False
|
||||||
certfile=/home/gns3/.config/GNS3/ssl/server.cert
|
certfile = /home/gns3/.config/GNS3/ssl/server.cert
|
||||||
certkey=/home/gns3/.config/GNS3/ssl/server.key
|
certkey = /home/gns3/.config/GNS3/ssl/server.key
|
||||||
|
|
||||||
|
; Options for JWT tokens (user authentication)
|
||||||
|
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
|
||||||
|
jwt_algorithm = HS256
|
||||||
|
jwt_access_token_expire_minutes = 1440
|
||||||
|
|
||||||
; Path where devices images are stored
|
; Path where devices images are stored
|
||||||
images_path = /home/gns3/GNS3/images
|
images_path = /home/gns3/GNS3/images
|
||||||
|
@ -25,7 +25,12 @@ from uuid import UUID
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from gns3server import schemas
|
from gns3server import schemas
|
||||||
from gns3server.controller.controller_error import ControllerBadRequestError, ControllerNotFoundError
|
from gns3server.controller.controller_error import (
|
||||||
|
ControllerBadRequestError,
|
||||||
|
ControllerNotFoundError,
|
||||||
|
ControllerUnauthorizedError
|
||||||
|
)
|
||||||
|
|
||||||
from gns3server.db.repositories.users import UsersRepository
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
from gns3server.services import auth_service
|
from gns3server.services import auth_service
|
||||||
|
|
||||||
@ -98,11 +103,18 @@ async def update_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
|
async def delete_user(
|
||||||
|
user_id: UUID,
|
||||||
|
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
|
current_user: schemas.User = Depends(get_current_active_user)
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Delete an user.
|
Delete an user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if current_user.is_superuser:
|
||||||
|
raise ControllerUnauthorizedError("The super user cannot be deleted")
|
||||||
|
|
||||||
success = await user_repo.delete_user(user_id)
|
success = await user_repo.delete_user(user_id)
|
||||||
if not success:
|
if not success:
|
||||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||||
|
@ -182,9 +182,21 @@ class Config:
|
|||||||
controller_config_filename = "gns3_controller.conf"
|
controller_config_filename = "gns3_controller.conf"
|
||||||
return os.path.join(self.config_dir, controller_config_filename)
|
return os.path.join(self.config_dir, controller_config_filename)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_config(self):
|
||||||
|
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
|
server_config_filename = "gns3_server.ini"
|
||||||
|
else:
|
||||||
|
server_config_filename = "gns3_server.conf"
|
||||||
|
return os.path.join(self.config_dir, server_config_filename)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""Restart with a clean config"""
|
"""
|
||||||
self._config = configparser.RawConfigParser()
|
Restart with a clean config
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._config = configparser.ConfigParser(interpolation=None)
|
||||||
# Override config from command line even if we modify the config file and live reload it.
|
# Override config from command line even if we modify the config file and live reload it.
|
||||||
self._override_config = {}
|
self._override_config = {}
|
||||||
|
|
||||||
@ -231,6 +243,18 @@ class Config:
|
|||||||
log.info("Load configuration file {}".format(file))
|
log.info("Load configuration file {}".format(file))
|
||||||
self._watched_files[file] = os.stat(file).st_mtime
|
self._watched_files[file] = os.stat(file).st_mtime
|
||||||
|
|
||||||
|
def write_config(self):
|
||||||
|
"""
|
||||||
|
Write the server configuration file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.dirname(self.server_config), exist_ok=True)
|
||||||
|
with open(self.server_config, 'w+') as fd:
|
||||||
|
self._config.write(fd)
|
||||||
|
except OSError as e:
|
||||||
|
log.error("Cannot write server configuration file '{}': {}".format(self.server_config, e))
|
||||||
|
|
||||||
def get_default_section(self):
|
def get_default_section(self):
|
||||||
"""
|
"""
|
||||||
Get the default configuration section.
|
Get the default configuration section.
|
||||||
|
@ -74,6 +74,7 @@ class BaseTable(Base):
|
|||||||
def generate_uuid():
|
def generate_uuid():
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
class User(BaseTable):
|
class User(BaseTable):
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
@ -22,6 +22,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
from .models import Base
|
from .models import Base
|
||||||
|
from gns3server.config import Config
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -29,12 +30,13 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
async def connect_to_db(app: FastAPI) -> None:
|
async def connect_to_db(app: FastAPI) -> None:
|
||||||
|
|
||||||
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
|
db_path = os.path.join(Config.instance().config_dir, "gns3_controller.db")
|
||||||
|
db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite:///{db_path}")
|
||||||
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||||
try:
|
try:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
log.info("Successfully connected to the database")
|
log.info(f"Successfully connected to database '{db_url}'")
|
||||||
app.state._db_engine = engine
|
app.state._db_engine = engine
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
log.error(f"Error while connecting to the database: {e}")
|
log.error(f"Error while connecting to database '{db_url}: {e}")
|
||||||
|
@ -31,6 +31,7 @@ import asyncio
|
|||||||
import signal
|
import signal
|
||||||
import functools
|
import functools
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import secrets
|
||||||
|
|
||||||
from gns3server.controller import Controller
|
from gns3server.controller import Controller
|
||||||
from gns3server.compute.port_manager import PortManager
|
from gns3server.compute.port_manager import PortManager
|
||||||
@ -122,7 +123,7 @@ def parse_arguments(argv):
|
|||||||
config = Config.instance().get_section_config("Server")
|
config = Config.instance().get_section_config("Server")
|
||||||
defaults = {
|
defaults = {
|
||||||
"host": config.get("host", "0.0.0.0"),
|
"host": config.get("host", "0.0.0.0"),
|
||||||
"port": config.get("port", 3080),
|
"port": config.getint("port", 3080),
|
||||||
"ssl": config.getboolean("ssl", False),
|
"ssl": config.getboolean("ssl", False),
|
||||||
"certfile": config.get("certfile", ""),
|
"certfile": config.get("certfile", ""),
|
||||||
"certkey": config.get("certkey", ""),
|
"certkey": config.get("certkey", ""),
|
||||||
@ -132,8 +133,8 @@ def parse_arguments(argv):
|
|||||||
"quiet": config.getboolean("quiet", False),
|
"quiet": config.getboolean("quiet", False),
|
||||||
"debug": config.getboolean("debug", False),
|
"debug": config.getboolean("debug", False),
|
||||||
"logfile": config.getboolean("logfile", ""),
|
"logfile": config.getboolean("logfile", ""),
|
||||||
"logmaxsize": config.get("logmaxsize", 10000000), # default is 10MB
|
"logmaxsize": config.getint("logmaxsize", 10000000), # default is 10MB
|
||||||
"logbackupcount": config.get("logbackupcount", 10),
|
"logbackupcount": config.getint("logbackupcount", 10),
|
||||||
"logcompression": config.getboolean("logcompression", False)
|
"logcompression": config.getboolean("logcompression", False)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,6 +146,13 @@ def set_config(args):
|
|||||||
|
|
||||||
config = Config.instance()
|
config = Config.instance()
|
||||||
server_config = config.get_section_config("Server")
|
server_config = config.get_section_config("Server")
|
||||||
|
jwt_secret_key = server_config.get("jwt_secret_key", None)
|
||||||
|
if not jwt_secret_key:
|
||||||
|
log.info("No JWT secret key configured, generating one...")
|
||||||
|
if not config._config.has_section("Server"):
|
||||||
|
config._config.add_section("Server")
|
||||||
|
config._config.set("Server", "jwt_secret_key", secrets.token_hex(32))
|
||||||
|
config.write_config()
|
||||||
server_config["local"] = str(args.local)
|
server_config["local"] = str(args.local)
|
||||||
server_config["allow_remote_console"] = str(args.allow)
|
server_config["allow_remote_console"] = str(args.allow)
|
||||||
server_config["host"] = args.host
|
server_config["host"] = args.host
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
@ -24,19 +23,22 @@ from passlib.context import CryptContext
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from gns3server.schemas.tokens import TokenData
|
from gns3server.schemas.tokens import TokenData
|
||||||
|
from gns3server.controller.controller_error import ControllerError
|
||||||
|
from gns3server.config import Config
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
# FIXME: temporary variables to move to config
|
import logging
|
||||||
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
|
log = logging.getLogger(__name__)
|
||||||
ALGORITHM = "HS256"
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
||||||
|
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
|
||||||
|
self._server_config = Config.instance().get_section_config("Server")
|
||||||
|
|
||||||
def hash_password(self, password: str) -> str:
|
def hash_password(self, password: str) -> str:
|
||||||
|
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
@ -45,19 +47,40 @@ class AuthService:
|
|||||||
|
|
||||||
return pwd_context.verify(password, hashed_password)
|
return pwd_context.verify(password, hashed_password)
|
||||||
|
|
||||||
|
def get_secret_key(self):
|
||||||
|
"""
|
||||||
|
Should only be used by tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._server_config.get("jwt_secret_key", None)
|
||||||
|
|
||||||
|
def get_algorithm(self):
|
||||||
|
"""
|
||||||
|
Should only be used by tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._server_config.get("jwt_algorithm", None)
|
||||||
|
|
||||||
def create_access_token(
|
def create_access_token(
|
||||||
self,
|
self,
|
||||||
username,
|
username,
|
||||||
secret_key: str = SECRET_KEY,
|
secret_key: str = None,
|
||||||
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
|
expires_in: int = 0
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
|
if not expires_in:
|
||||||
|
expires_in = self._server_config.getint("jwt_access_token_expire_minutes", 1440)
|
||||||
expire = datetime.utcnow() + timedelta(minutes=expires_in)
|
expire = datetime.utcnow() + timedelta(minutes=expires_in)
|
||||||
to_encode = {"sub": username, "exp": expire}
|
to_encode = {"sub": username, "exp": expire}
|
||||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
if secret_key is None:
|
||||||
|
secret_key = self._server_config.get("jwt_secret_key", None)
|
||||||
|
if secret_key is None:
|
||||||
|
raise ControllerError("No JWT secret key has been configured")
|
||||||
|
algorithm = self._server_config.get("jwt_algorithm", "HS256")
|
||||||
|
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
|
def get_username_from_token(self, token: str, secret_key: str = None) -> Optional[str]:
|
||||||
|
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -65,7 +88,12 @@ class AuthService:
|
|||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
if secret_key is None:
|
||||||
|
secret_key = self._server_config.get("jwt_secret_key", None)
|
||||||
|
if secret_key is None:
|
||||||
|
raise ControllerError("No JWT secret key has been configured")
|
||||||
|
algorithm = self._server_config.get("jwt_algorithm", "HS256")
|
||||||
|
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
@ -17,16 +17,15 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
from fastapi import FastAPI, HTTPException, status
|
from fastapi import FastAPI, HTTPException, status
|
||||||
from starlette.datastructures import Secret
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from gns3server.db.repositories.users import UsersRepository
|
from gns3server.db.repositories.users import UsersRepository
|
||||||
from gns3server.services import auth_service
|
from gns3server.services import auth_service
|
||||||
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
|
from gns3server.config import Config
|
||||||
from gns3server.schemas.users import User
|
from gns3server.schemas.users import User
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
@ -36,7 +35,7 @@ class TestUserRoutes:
|
|||||||
|
|
||||||
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
|
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
|
||||||
|
|
||||||
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
|
new_user = {"username": "user1", "email": "user1@email.com", "password": "test_password"}
|
||||||
response = await client.post(app.url_path_for("create_user"), json=new_user)
|
response = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||||
assert response.status_code != status.HTTP_404_NOT_FOUND
|
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ class TestUserRoutes:
|
|||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
user_repo = UsersRepository(db_session)
|
user_repo = UsersRepository(db_session)
|
||||||
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
|
params = {"username": "user2", "email": "user2@email.com", "password": "test_password"}
|
||||||
|
|
||||||
# make sure the user doesn't exist in the database
|
# make sure the user doesn't exist in the database
|
||||||
user_in_db = await user_repo.get_user_by_username(params["username"])
|
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||||
@ -72,7 +71,7 @@ class TestUserRoutes:
|
|||||||
"attr, value, status_code",
|
"attr, value, status_code",
|
||||||
(
|
(
|
||||||
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
|
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
|
||||||
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
|
("username", "user2", status.HTTP_400_BAD_REQUEST),
|
||||||
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||||
@ -101,7 +100,7 @@ class TestUserRoutes:
|
|||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
user_repo = UsersRepository(db_session)
|
user_repo = UsersRepository(db_session)
|
||||||
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
|
new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"}
|
||||||
|
|
||||||
# send post request to create user and ensure it is successful
|
# send post request to create user and ensure it is successful
|
||||||
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||||
@ -114,6 +113,12 @@ class TestUserRoutes:
|
|||||||
assert user_in_db.hashed_password != new_user["password"]
|
assert user_in_db.hashed_password != new_user["password"]
|
||||||
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
|
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
|
||||||
|
|
||||||
|
async def test_get_users(self, app: FastAPI, client: AsyncClient) -> None:
|
||||||
|
|
||||||
|
response = await client.get(app.url_path_for("get_users"))
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert len(response.json()) == 3 # user1, user2 and user3 should exist
|
||||||
|
|
||||||
|
|
||||||
class TestAuthTokens:
|
class TestAuthTokens:
|
||||||
|
|
||||||
@ -124,16 +129,18 @@ class TestAuthTokens:
|
|||||||
test_user: User
|
test_user: User
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||||
token = auth_service.create_access_token(test_user.username)
|
token = auth_service.create_access_token(test_user.username)
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||||
username = payload.get("sub")
|
username = payload.get("sub")
|
||||||
assert username == test_user.username
|
assert username == test_user.username
|
||||||
|
|
||||||
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
|
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None:
|
||||||
|
|
||||||
|
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||||
token = auth_service.create_access_token(None)
|
token = auth_service.create_access_token(None)
|
||||||
with pytest.raises(jwt.JWTError):
|
with pytest.raises(jwt.JWTError):
|
||||||
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||||
|
|
||||||
async def test_can_retrieve_username_from_token(
|
async def test_can_retrieve_username_from_token(
|
||||||
self,
|
self,
|
||||||
@ -148,10 +155,10 @@ class TestAuthTokens:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"secret, wrong_token",
|
"wrong_secret, wrong_token",
|
||||||
(
|
(
|
||||||
(SECRET_KEY, "asdf"), # use wrong token
|
("use correct secret", "asdf"), # use wrong token
|
||||||
(SECRET_KEY, ""), # use wrong token
|
("use correct secret", ""), # use wrong token
|
||||||
("ABC123", "use correct token"), # use wrong secret
|
("ABC123", "use correct token"), # use wrong secret
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -160,15 +167,17 @@ class TestAuthTokens:
|
|||||||
app: FastAPI,
|
app: FastAPI,
|
||||||
client: AsyncClient,
|
client: AsyncClient,
|
||||||
test_user: User,
|
test_user: User,
|
||||||
secret: Union[Secret, str],
|
wrong_secret: str,
|
||||||
wrong_token: Optional[str],
|
wrong_token: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
token = auth_service.create_access_token(test_user.username)
|
token = auth_service.create_access_token(test_user.username)
|
||||||
|
if wrong_secret == "use correct secret":
|
||||||
|
wrong_secret = auth_service._server_config.get("jwt_secret_key")
|
||||||
if wrong_token == "use correct token":
|
if wrong_token == "use correct token":
|
||||||
wrong_token = token
|
wrong_token = token
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
|
auth_service.get_username_from_token(wrong_token, secret_key=wrong_secret)
|
||||||
|
|
||||||
|
|
||||||
class TestUserLogin:
|
class TestUserLogin:
|
||||||
@ -189,8 +198,9 @@ class TestUserLogin:
|
|||||||
assert res.status_code == status.HTTP_200_OK
|
assert res.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
# check that token exists in response and has user encoded within it
|
# check that token exists in response and has user encoded within it
|
||||||
|
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||||
token = res.json().get("access_token")
|
token = res.json().get("access_token")
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||||
assert "sub" in payload
|
assert "sub" in payload
|
||||||
username = payload.get("sub")
|
username = payload.get("sub")
|
||||||
assert username == test_user.username
|
assert username == test_user.username
|
||||||
|
Loading…
Reference in New Issue
Block a user