mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-13 09:30:54 +00:00
Generate JWT secret key if none is configured in the config file.
Change location of the database.
This commit is contained in:
parent
509e762cda
commit
bde706d19a
@ -4,11 +4,16 @@ host = 0.0.0.0
|
||||
; HTTP port for controlling the servers
|
||||
port = 3080
|
||||
|
||||
; Option to enable SSL encryption
|
||||
; Options to enable SSL encryption
|
||||
ssl = False
|
||||
certfile = /home/gns3/.config/GNS3/ssl/server.cert
|
||||
certkey = /home/gns3/.config/GNS3/ssl/server.key
|
||||
|
||||
; Options for JWT tokens (user authentication)
|
||||
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
|
||||
jwt_algorithm = HS256
|
||||
jwt_access_token_expire_minutes = 1440
|
||||
|
||||
; Path where devices images are stored
|
||||
images_path = /home/gns3/GNS3/images
|
||||
|
||||
|
@ -25,7 +25,12 @@ from uuid import UUID
|
||||
from typing import List
|
||||
|
||||
from gns3server import schemas
|
||||
from gns3server.controller.controller_error import ControllerBadRequestError, ControllerNotFoundError
|
||||
from gns3server.controller.controller_error import (
|
||||
ControllerBadRequestError,
|
||||
ControllerNotFoundError,
|
||||
ControllerUnauthorizedError
|
||||
)
|
||||
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.services import auth_service
|
||||
|
||||
@ -98,11 +103,18 @@ async def update_user(
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
current_user: schemas.User = Depends(get_current_active_user)
|
||||
) -> None:
|
||||
"""
|
||||
Delete an user.
|
||||
"""
|
||||
|
||||
if current_user.is_superuser:
|
||||
raise ControllerUnauthorizedError("The super user cannot be deleted")
|
||||
|
||||
success = await user_repo.delete_user(user_id)
|
||||
if not success:
|
||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||
|
@ -182,9 +182,21 @@ class Config:
|
||||
controller_config_filename = "gns3_controller.conf"
|
||||
return os.path.join(self.config_dir, controller_config_filename)
|
||||
|
||||
@property
|
||||
def server_config(self):
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
server_config_filename = "gns3_server.ini"
|
||||
else:
|
||||
server_config_filename = "gns3_server.conf"
|
||||
return os.path.join(self.config_dir, server_config_filename)
|
||||
|
||||
def clear(self):
|
||||
"""Restart with a clean config"""
|
||||
self._config = configparser.RawConfigParser()
|
||||
"""
|
||||
Restart with a clean config
|
||||
"""
|
||||
|
||||
self._config = configparser.ConfigParser(interpolation=None)
|
||||
# Override config from command line even if we modify the config file and live reload it.
|
||||
self._override_config = {}
|
||||
|
||||
@ -231,6 +243,18 @@ class Config:
|
||||
log.info("Load configuration file {}".format(file))
|
||||
self._watched_files[file] = os.stat(file).st_mtime
|
||||
|
||||
def write_config(self):
|
||||
"""
|
||||
Write the server configuration file.
|
||||
"""
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.server_config), exist_ok=True)
|
||||
with open(self.server_config, 'w+') as fd:
|
||||
self._config.write(fd)
|
||||
except OSError as e:
|
||||
log.error("Cannot write server configuration file '{}': {}".format(self.server_config, e))
|
||||
|
||||
def get_default_section(self):
|
||||
"""
|
||||
Get the default configuration section.
|
||||
|
@ -74,6 +74,7 @@ class BaseTable(Base):
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class User(BaseTable):
|
||||
|
||||
__tablename__ = "users"
|
||||
|
@ -22,6 +22,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from .models import Base
|
||||
from gns3server.config import Config
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
@ -29,12 +30,13 @@ log = logging.getLogger(__name__)
|
||||
|
||||
async def connect_to_db(app: FastAPI) -> None:
|
||||
|
||||
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
|
||||
db_path = os.path.join(Config.instance().config_dir, "gns3_controller.db")
|
||||
db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite:///{db_path}")
|
||||
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
log.info("Successfully connected to the database")
|
||||
log.info(f"Successfully connected to database '{db_url}'")
|
||||
app.state._db_engine = engine
|
||||
except SQLAlchemyError as e:
|
||||
log.error(f"Error while connecting to the database: {e}")
|
||||
log.error(f"Error while connecting to database '{db_url}: {e}")
|
||||
|
@ -31,6 +31,7 @@ import asyncio
|
||||
import signal
|
||||
import functools
|
||||
import uvicorn
|
||||
import secrets
|
||||
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.compute.port_manager import PortManager
|
||||
@ -122,7 +123,7 @@ def parse_arguments(argv):
|
||||
config = Config.instance().get_section_config("Server")
|
||||
defaults = {
|
||||
"host": config.get("host", "0.0.0.0"),
|
||||
"port": config.get("port", 3080),
|
||||
"port": config.getint("port", 3080),
|
||||
"ssl": config.getboolean("ssl", False),
|
||||
"certfile": config.get("certfile", ""),
|
||||
"certkey": config.get("certkey", ""),
|
||||
@ -132,8 +133,8 @@ def parse_arguments(argv):
|
||||
"quiet": config.getboolean("quiet", False),
|
||||
"debug": config.getboolean("debug", False),
|
||||
"logfile": config.getboolean("logfile", ""),
|
||||
"logmaxsize": config.get("logmaxsize", 10000000), # default is 10MB
|
||||
"logbackupcount": config.get("logbackupcount", 10),
|
||||
"logmaxsize": config.getint("logmaxsize", 10000000), # default is 10MB
|
||||
"logbackupcount": config.getint("logbackupcount", 10),
|
||||
"logcompression": config.getboolean("logcompression", False)
|
||||
}
|
||||
|
||||
@ -145,6 +146,13 @@ def set_config(args):
|
||||
|
||||
config = Config.instance()
|
||||
server_config = config.get_section_config("Server")
|
||||
jwt_secret_key = server_config.get("jwt_secret_key", None)
|
||||
if not jwt_secret_key:
|
||||
log.info("No JWT secret key configured, generating one...")
|
||||
if not config._config.has_section("Server"):
|
||||
config._config.add_section("Server")
|
||||
config._config.set("Server", "jwt_secret_key", secrets.token_hex(32))
|
||||
config.write_config()
|
||||
server_config["local"] = str(args.local)
|
||||
server_config["allow_remote_console"] = str(args.allow)
|
||||
server_config["host"] = args.host
|
||||
|
@ -16,7 +16,6 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
@ -24,19 +23,22 @@ from passlib.context import CryptContext
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, status
|
||||
from gns3server.schemas.tokens import TokenData
|
||||
from gns3server.controller.controller_error import ControllerError
|
||||
from gns3server.config import Config
|
||||
from pydantic import ValidationError
|
||||
|
||||
# FIXME: temporary variables to move to config
|
||||
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class AuthService:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self._server_config = Config.instance().get_section_config("Server")
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
|
||||
return pwd_context.hash(password)
|
||||
@ -45,19 +47,40 @@ class AuthService:
|
||||
|
||||
return pwd_context.verify(password, hashed_password)
|
||||
|
||||
def get_secret_key(self):
|
||||
"""
|
||||
Should only be used by tests.
|
||||
"""
|
||||
|
||||
return self._server_config.get("jwt_secret_key", None)
|
||||
|
||||
def get_algorithm(self):
|
||||
"""
|
||||
Should only be used by tests.
|
||||
"""
|
||||
|
||||
return self._server_config.get("jwt_algorithm", None)
|
||||
|
||||
def create_access_token(
|
||||
self,
|
||||
username,
|
||||
secret_key: str = SECRET_KEY,
|
||||
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
secret_key: str = None,
|
||||
expires_in: int = 0
|
||||
) -> str:
|
||||
|
||||
if not expires_in:
|
||||
expires_in = self._server_config.getint("jwt_access_token_expire_minutes", 1440)
|
||||
expire = datetime.utcnow() + timedelta(minutes=expires_in)
|
||||
to_encode = {"sub": username, "exp": expire}
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||||
if secret_key is None:
|
||||
secret_key = self._server_config.get("jwt_secret_key", None)
|
||||
if secret_key is None:
|
||||
raise ControllerError("No JWT secret key has been configured")
|
||||
algorithm = self._server_config.get("jwt_algorithm", "HS256")
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
|
||||
def get_username_from_token(self, token: str, secret_key: str = None) -> Optional[str]:
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -65,7 +88,12 @@ class AuthService:
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||
if secret_key is None:
|
||||
secret_key = self._server_config.get("jwt_secret_key", None)
|
||||
if secret_key is None:
|
||||
raise ControllerError("No JWT secret key has been configured")
|
||||
algorithm = self._server_config.get("jwt_algorithm", "HS256")
|
||||
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
@ -17,16 +17,15 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from starlette.datastructures import Secret
|
||||
from httpx import AsyncClient
|
||||
from jose import jwt
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.services import auth_service
|
||||
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
|
||||
from gns3server.config import Config
|
||||
from gns3server.schemas.users import User
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
@ -36,7 +35,7 @@ class TestUserRoutes:
|
||||
|
||||
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
|
||||
new_user = {"username": "user1", "email": "user1@email.com", "password": "test_password"}
|
||||
response = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||
|
||||
@ -48,7 +47,7 @@ class TestUserRoutes:
|
||||
) -> None:
|
||||
|
||||
user_repo = UsersRepository(db_session)
|
||||
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
|
||||
params = {"username": "user2", "email": "user2@email.com", "password": "test_password"}
|
||||
|
||||
# make sure the user doesn't exist in the database
|
||||
user_in_db = await user_repo.get_user_by_username(params["username"])
|
||||
@ -72,7 +71,7 @@ class TestUserRoutes:
|
||||
"attr, value, status_code",
|
||||
(
|
||||
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
|
||||
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
|
||||
("username", "user2", status.HTTP_400_BAD_REQUEST),
|
||||
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
|
||||
@ -101,7 +100,7 @@ class TestUserRoutes:
|
||||
) -> None:
|
||||
|
||||
user_repo = UsersRepository(db_session)
|
||||
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
|
||||
new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"}
|
||||
|
||||
# send post request to create user and ensure it is successful
|
||||
res = await client.post(app.url_path_for("create_user"), json=new_user)
|
||||
@ -114,6 +113,12 @@ class TestUserRoutes:
|
||||
assert user_in_db.hashed_password != new_user["password"]
|
||||
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
|
||||
|
||||
async def test_get_users(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
response = await client.get(app.url_path_for("get_users"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()) == 3 # user1, user2 and user3 should exist
|
||||
|
||||
|
||||
class TestAuthTokens:
|
||||
|
||||
@ -124,16 +129,18 @@ class TestAuthTokens:
|
||||
test_user: User
|
||||
) -> None:
|
||||
|
||||
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||
token = auth_service.create_access_token(test_user.username)
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||
username = payload.get("sub")
|
||||
assert username == test_user.username
|
||||
|
||||
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None:
|
||||
|
||||
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||
token = auth_service.create_access_token(None)
|
||||
with pytest.raises(jwt.JWTError):
|
||||
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||
|
||||
async def test_can_retrieve_username_from_token(
|
||||
self,
|
||||
@ -148,10 +155,10 @@ class TestAuthTokens:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"secret, wrong_token",
|
||||
"wrong_secret, wrong_token",
|
||||
(
|
||||
(SECRET_KEY, "asdf"), # use wrong token
|
||||
(SECRET_KEY, ""), # use wrong token
|
||||
("use correct secret", "asdf"), # use wrong token
|
||||
("use correct secret", ""), # use wrong token
|
||||
("ABC123", "use correct token"), # use wrong secret
|
||||
),
|
||||
)
|
||||
@ -160,15 +167,17 @@ class TestAuthTokens:
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
test_user: User,
|
||||
secret: Union[Secret, str],
|
||||
wrong_secret: str,
|
||||
wrong_token: Optional[str],
|
||||
) -> None:
|
||||
|
||||
token = auth_service.create_access_token(test_user.username)
|
||||
if wrong_secret == "use correct secret":
|
||||
wrong_secret = auth_service._server_config.get("jwt_secret_key")
|
||||
if wrong_token == "use correct token":
|
||||
wrong_token = token
|
||||
with pytest.raises(HTTPException):
|
||||
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
|
||||
auth_service.get_username_from_token(wrong_token, secret_key=wrong_secret)
|
||||
|
||||
|
||||
class TestUserLogin:
|
||||
@ -189,8 +198,9 @@ class TestUserLogin:
|
||||
assert res.status_code == status.HTTP_200_OK
|
||||
|
||||
# check that token exists in response and has user encoded within it
|
||||
secret_key = auth_service._server_config.get("jwt_secret_key")
|
||||
token = res.json().get("access_token")
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
|
||||
assert "sub" in payload
|
||||
username = payload.get("sub")
|
||||
assert username == test_user.username
|
||||
|
Loading…
Reference in New Issue
Block a user