1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-12-26 00:38:10 +00:00

Generate JWT secret key if none is configured in the config file.

Change location of the database.
This commit is contained in:
grossmj 2020-12-16 18:24:21 +10:30
parent 509e762cda
commit bde706d19a
8 changed files with 130 additions and 40 deletions

View File

@ -4,10 +4,15 @@ host = 0.0.0.0
; HTTP port for controlling the servers ; HTTP port for controlling the servers
port = 3080 port = 3080
; Option to enable SSL encryption ; Options to enable SSL encryption
ssl = False ssl = False
certfile=/home/gns3/.config/GNS3/ssl/server.cert certfile = /home/gns3/.config/GNS3/ssl/server.cert
certkey=/home/gns3/.config/GNS3/ssl/server.key certkey = /home/gns3/.config/GNS3/ssl/server.key
; Options for JWT tokens (user authentication)
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
jwt_algorithm = HS256
jwt_access_token_expire_minutes = 1440
; Path where devices images are stored ; Path where devices images are stored
images_path = /home/gns3/GNS3/images images_path = /home/gns3/GNS3/images

View File

@ -25,7 +25,12 @@ from uuid import UUID
from typing import List from typing import List
from gns3server import schemas from gns3server import schemas
from gns3server.controller.controller_error import ControllerBadRequestError, ControllerNotFoundError from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerUnauthorizedError
)
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service from gns3server.services import auth_service
@ -98,11 +103,18 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))): async def delete_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
) -> None:
""" """
Delete an user. Delete an user.
""" """
if current_user.is_superuser:
raise ControllerUnauthorizedError("The super user cannot be deleted")
success = await user_repo.delete_user(user_id) success = await user_repo.delete_user(user_id)
if not success: if not success:
raise ControllerNotFoundError(f"User '{user_id}' not found") raise ControllerNotFoundError(f"User '{user_id}' not found")

View File

@ -182,9 +182,21 @@ class Config:
controller_config_filename = "gns3_controller.conf" controller_config_filename = "gns3_controller.conf"
return os.path.join(self.config_dir, controller_config_filename) return os.path.join(self.config_dir, controller_config_filename)
@property
def server_config(self):
if sys.platform.startswith("win"):
server_config_filename = "gns3_server.ini"
else:
server_config_filename = "gns3_server.conf"
return os.path.join(self.config_dir, server_config_filename)
def clear(self): def clear(self):
"""Restart with a clean config""" """
self._config = configparser.RawConfigParser() Restart with a clean config
"""
self._config = configparser.ConfigParser(interpolation=None)
# Override config from command line even if we modify the config file and live reload it. # Override config from command line even if we modify the config file and live reload it.
self._override_config = {} self._override_config = {}
@ -231,6 +243,18 @@ class Config:
log.info("Load configuration file {}".format(file)) log.info("Load configuration file {}".format(file))
self._watched_files[file] = os.stat(file).st_mtime self._watched_files[file] = os.stat(file).st_mtime
def write_config(self):
"""
Write the server configuration file.
"""
try:
os.makedirs(os.path.dirname(self.server_config), exist_ok=True)
with open(self.server_config, 'w+') as fd:
self._config.write(fd)
except OSError as e:
log.error("Cannot write server configuration file '{}': {}".format(self.server_config, e))
def get_default_section(self): def get_default_section(self):
""" """
Get the default configuration section. Get the default configuration section.

View File

@ -74,6 +74,7 @@ class BaseTable(Base):
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())
class User(BaseTable): class User(BaseTable):
__tablename__ = "users" __tablename__ = "users"

View File

@ -22,6 +22,7 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from .models import Base from .models import Base
from gns3server.config import Config
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -29,12 +30,13 @@ log = logging.getLogger(__name__)
async def connect_to_db(app: FastAPI) -> None: async def connect_to_db(app: FastAPI) -> None:
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db") db_path = os.path.join(Config.instance().config_dir, "gns3_controller.db")
db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite:///{db_path}")
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True) engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
try: try:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
log.info("Successfully connected to the database") log.info(f"Successfully connected to database '{db_url}'")
app.state._db_engine = engine app.state._db_engine = engine
except SQLAlchemyError as e: except SQLAlchemyError as e:
log.error(f"Error while connecting to the database: {e}") log.error(f"Error while connecting to database '{db_url}: {e}")

View File

@ -31,6 +31,7 @@ import asyncio
import signal import signal
import functools import functools
import uvicorn import uvicorn
import secrets
from gns3server.controller import Controller from gns3server.controller import Controller
from gns3server.compute.port_manager import PortManager from gns3server.compute.port_manager import PortManager
@ -122,7 +123,7 @@ def parse_arguments(argv):
config = Config.instance().get_section_config("Server") config = Config.instance().get_section_config("Server")
defaults = { defaults = {
"host": config.get("host", "0.0.0.0"), "host": config.get("host", "0.0.0.0"),
"port": config.get("port", 3080), "port": config.getint("port", 3080),
"ssl": config.getboolean("ssl", False), "ssl": config.getboolean("ssl", False),
"certfile": config.get("certfile", ""), "certfile": config.get("certfile", ""),
"certkey": config.get("certkey", ""), "certkey": config.get("certkey", ""),
@ -132,8 +133,8 @@ def parse_arguments(argv):
"quiet": config.getboolean("quiet", False), "quiet": config.getboolean("quiet", False),
"debug": config.getboolean("debug", False), "debug": config.getboolean("debug", False),
"logfile": config.getboolean("logfile", ""), "logfile": config.getboolean("logfile", ""),
"logmaxsize": config.get("logmaxsize", 10000000), # default is 10MB "logmaxsize": config.getint("logmaxsize", 10000000), # default is 10MB
"logbackupcount": config.get("logbackupcount", 10), "logbackupcount": config.getint("logbackupcount", 10),
"logcompression": config.getboolean("logcompression", False) "logcompression": config.getboolean("logcompression", False)
} }
@ -145,6 +146,13 @@ def set_config(args):
config = Config.instance() config = Config.instance()
server_config = config.get_section_config("Server") server_config = config.get_section_config("Server")
jwt_secret_key = server_config.get("jwt_secret_key", None)
if not jwt_secret_key:
log.info("No JWT secret key configured, generating one...")
if not config._config.has_section("Server"):
config._config.add_section("Server")
config._config.set("Server", "jwt_secret_key", secrets.token_hex(32))
config.write_config()
server_config["local"] = str(args.local) server_config["local"] = str(args.local)
server_config["allow_remote_console"] = str(args.allow) server_config["allow_remote_console"] = str(args.allow)
server_config["host"] = args.host server_config["host"] = args.host

View File

@ -16,7 +16,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import bcrypt
from jose import JWTError, jwt from jose import JWTError, jwt
from datetime import datetime, timedelta from datetime import datetime, timedelta
from passlib.context import CryptContext from passlib.context import CryptContext
@ -24,19 +23,22 @@ from passlib.context import CryptContext
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
from gns3server.schemas.tokens import TokenData from gns3server.schemas.tokens import TokenData
from gns3server.controller.controller_error import ControllerError
from gns3server.config import Config
from pydantic import ValidationError from pydantic import ValidationError
# FIXME: temporary variables to move to config import logging
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" log = logging.getLogger(__name__)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class AuthService: class AuthService:
def __init__(self):
self._server_config = Config.instance().get_section_config("Server")
def hash_password(self, password: str) -> str: def hash_password(self, password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
@ -45,19 +47,40 @@ class AuthService:
return pwd_context.verify(password, hashed_password) return pwd_context.verify(password, hashed_password)
def get_secret_key(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_secret_key", None)
def get_algorithm(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_algorithm", None)
def create_access_token( def create_access_token(
self, self,
username, username,
secret_key: str = SECRET_KEY, secret_key: str = None,
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES expires_in: int = 0
) -> str: ) -> str:
if not expires_in:
expires_in = self._server_config.getint("jwt_access_token_expire_minutes", 1440)
expire = datetime.utcnow() + timedelta(minutes=expires_in) expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire} to_encode = {"sub": username, "exp": expire}
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None)
if secret_key is None:
raise ControllerError("No JWT secret key has been configured")
algorithm = self._server_config.get("jwt_algorithm", "HS256")
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt return encoded_jwt
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]: def get_username_from_token(self, token: str, secret_key: str = None) -> Optional[str]:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -65,7 +88,12 @@ class AuthService:
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None)
if secret_key is None:
raise ControllerError("No JWT secret key has been configured")
algorithm = self._server_config.get("jwt_algorithm", "HS256")
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise credentials_exception raise credentials_exception

View File

@ -17,16 +17,15 @@
import pytest import pytest
from typing import Optional, Union from typing import Optional
from fastapi import FastAPI, HTTPException, status from fastapi import FastAPI, HTTPException, status
from starlette.datastructures import Secret
from httpx import AsyncClient from httpx import AsyncClient
from jose import jwt from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from gns3server.services.authentication import SECRET_KEY, ALGORITHM from gns3server.config import Config
from gns3server.schemas.users import User from gns3server.schemas.users import User
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -36,7 +35,7 @@ class TestUserRoutes:
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None: async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"} new_user = {"username": "user1", "email": "user1@email.com", "password": "test_password"}
response = await client.post(app.url_path_for("create_user"), json=new_user) response = await client.post(app.url_path_for("create_user"), json=new_user)
assert response.status_code != status.HTTP_404_NOT_FOUND assert response.status_code != status.HTTP_404_NOT_FOUND
@ -48,7 +47,7 @@ class TestUserRoutes:
) -> None: ) -> None:
user_repo = UsersRepository(db_session) user_repo = UsersRepository(db_session)
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"} params = {"username": "user2", "email": "user2@email.com", "password": "test_password"}
# make sure the user doesn't exist in the database # make sure the user doesn't exist in the database
user_in_db = await user_repo.get_user_by_username(params["username"]) user_in_db = await user_repo.get_user_by_username(params["username"])
@ -72,7 +71,7 @@ class TestUserRoutes:
"attr, value, status_code", "attr, value, status_code",
( (
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST), ("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
("username", "test_user2", status.HTTP_400_BAD_REQUEST), ("username", "user2", status.HTTP_400_BAD_REQUEST),
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY), ("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY), ("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY), ("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
@ -101,7 +100,7 @@ class TestUserRoutes:
) -> None: ) -> None:
user_repo = UsersRepository(db_session) user_repo = UsersRepository(db_session)
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"} new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"}
# send post request to create user and ensure it is successful # send post request to create user and ensure it is successful
res = await client.post(app.url_path_for("create_user"), json=new_user) res = await client.post(app.url_path_for("create_user"), json=new_user)
@ -114,6 +113,12 @@ class TestUserRoutes:
assert user_in_db.hashed_password != new_user["password"] assert user_in_db.hashed_password != new_user["password"]
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password) assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
async def test_get_users(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_users"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 3 # user1, user2 and user3 should exist
class TestAuthTokens: class TestAuthTokens:
@ -124,16 +129,18 @@ class TestAuthTokens:
test_user: User test_user: User
) -> None: ) -> None:
secret_key = auth_service._server_config.get("jwt_secret_key")
token = auth_service.create_access_token(test_user.username) token = auth_service.create_access_token(test_user.username)
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, secret_key, algorithms=["HS256"])
username = payload.get("sub") username = payload.get("sub")
assert username == test_user.username assert username == test_user.username
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None: async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None:
secret_key = auth_service._server_config.get("jwt_secret_key")
token = auth_service.create_access_token(None) token = auth_service.create_access_token(None)
with pytest.raises(jwt.JWTError): with pytest.raises(jwt.JWTError):
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) jwt.decode(token, secret_key, algorithms=["HS256"])
async def test_can_retrieve_username_from_token( async def test_can_retrieve_username_from_token(
self, self,
@ -148,10 +155,10 @@ class TestAuthTokens:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"secret, wrong_token", "wrong_secret, wrong_token",
( (
(SECRET_KEY, "asdf"), # use wrong token ("use correct secret", "asdf"), # use wrong token
(SECRET_KEY, ""), # use wrong token ("use correct secret", ""), # use wrong token
("ABC123", "use correct token"), # use wrong secret ("ABC123", "use correct token"), # use wrong secret
), ),
) )
@ -160,15 +167,17 @@ class TestAuthTokens:
app: FastAPI, app: FastAPI,
client: AsyncClient, client: AsyncClient,
test_user: User, test_user: User,
secret: Union[Secret, str], wrong_secret: str,
wrong_token: Optional[str], wrong_token: Optional[str],
) -> None: ) -> None:
token = auth_service.create_access_token(test_user.username) token = auth_service.create_access_token(test_user.username)
if wrong_secret == "use correct secret":
wrong_secret = auth_service._server_config.get("jwt_secret_key")
if wrong_token == "use correct token": if wrong_token == "use correct token":
wrong_token = token wrong_token = token
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
auth_service.get_username_from_token(wrong_token, secret_key=str(secret)) auth_service.get_username_from_token(wrong_token, secret_key=wrong_secret)
class TestUserLogin: class TestUserLogin:
@ -189,8 +198,9 @@ class TestUserLogin:
assert res.status_code == status.HTTP_200_OK assert res.status_code == status.HTTP_200_OK
# check that token exists in response and has user encoded within it # check that token exists in response and has user encoded within it
secret_key = auth_service._server_config.get("jwt_secret_key")
token = res.json().get("access_token") token = res.json().get("access_token")
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, secret_key, algorithms=["HS256"])
assert "sub" in payload assert "sub" in payload
username = payload.get("sub") username = payload.get("sub")
assert username == test_user.username assert username == test_user.username