mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-28 00:41:01 +00:00
Protect the API and add alternative authentication endpoint.
This commit is contained in:
parent
e28452f09a
commit
0465cb87f6
@ -14,7 +14,7 @@
|
|||||||
# You should have received a copy of the GNU General Public License
|
# You should have received a copy of the GNU General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from . import controller
|
from . import controller
|
||||||
from . import appliances
|
from . import appliances
|
||||||
@ -30,17 +30,80 @@ from . import symbols
|
|||||||
from . import templates
|
from . import templates
|
||||||
from . import users
|
from . import users
|
||||||
|
|
||||||
|
from .dependencies.authentication import get_current_active_user
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
router.include_router(controller.router, tags=["Controller"])
|
router.include_router(controller.router, tags=["Controller"])
|
||||||
router.include_router(users.router, prefix="/users", tags=["Users"])
|
router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||||
router.include_router(appliances.router, prefix="/appliances", tags=["Appliances"])
|
|
||||||
router.include_router(computes.router, prefix="/computes", tags=["Computes"])
|
router.include_router(
|
||||||
router.include_router(drawings.router, prefix="/projects/{project_id}/drawings", tags=["Drawings"])
|
appliances.router,
|
||||||
router.include_router(gns3vm.router, prefix="/gns3vm", tags=["GNS3 VM"])
|
dependencies=[Depends(get_current_active_user)],
|
||||||
router.include_router(links.router, prefix="/projects/{project_id}/links", tags=["Links"])
|
prefix="/appliances",
|
||||||
router.include_router(nodes.router, prefix="/projects/{project_id}/nodes", tags=["Nodes"])
|
tags=["Appliances"]
|
||||||
router.include_router(notifications.router, prefix="/notifications", tags=["Notifications"])
|
)
|
||||||
router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
|
||||||
router.include_router(snapshots.router, prefix="/projects/{project_id}/snapshots", tags=["Snapshots"])
|
router.include_router(
|
||||||
router.include_router(symbols.router, prefix="/symbols", tags=["Symbols"])
|
computes.router,
|
||||||
router.include_router(templates.router, tags=["Templates"])
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/computes",
|
||||||
|
tags=["Computes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
drawings.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/projects/{project_id}/drawings",
|
||||||
|
tags=["Drawings"])
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
gns3vm.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/gns3vm",
|
||||||
|
tags=["GNS3 VM"]
|
||||||
|
)
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
links.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/projects/{project_id}/links",
|
||||||
|
tags=["Links"]
|
||||||
|
)
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
nodes.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/projects/{project_id}/nodes",
|
||||||
|
tags=["Nodes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
notifications.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/notifications",
|
||||||
|
tags=["Notifications"])
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
projects.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/projects",
|
||||||
|
tags=["Projects"])
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
snapshots.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/projects/{project_id}/snapshots",
|
||||||
|
tags=["Snapshots"])
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
symbols.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
prefix="/symbols", tags=["Symbols"]
|
||||||
|
)
|
||||||
|
|
||||||
|
router.include_router(
|
||||||
|
templates.router,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
tags=["Templates"]
|
||||||
|
)
|
||||||
|
@ -18,7 +18,7 @@ import asyncio
|
|||||||
import signal
|
import signal
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, Depends, status
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -28,6 +28,7 @@ from gns3server.version import __version__
|
|||||||
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
|
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
|
||||||
from gns3server import schemas
|
from gns3server import schemas
|
||||||
|
|
||||||
|
from .dependencies.authentication import get_current_active_user
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -36,8 +37,39 @@ log = logging.getLogger(__name__)
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/version",
|
||||||
|
response_model=schemas.Version,
|
||||||
|
)
|
||||||
|
def get_version() -> dict:
|
||||||
|
"""
|
||||||
|
Return the server version number.
|
||||||
|
"""
|
||||||
|
|
||||||
|
local_server = Config.instance().settings.Server.local
|
||||||
|
return {"version": __version__, "local": local_server}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/version",
|
||||||
|
response_model=schemas.Version,
|
||||||
|
response_model_exclude_defaults=True,
|
||||||
|
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
|
||||||
|
)
|
||||||
|
def check_version(version: schemas.Version) -> dict:
|
||||||
|
"""
|
||||||
|
Check if version is the same as the server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(version.version)
|
||||||
|
if version.version != __version__:
|
||||||
|
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
|
||||||
|
return {"version": __version__}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/shutdown",
|
"/shutdown",
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
responses={403: {"model": schemas.ErrorMessage, "description": "Server shutdown not allowed"}},
|
responses={403: {"model": schemas.ErrorMessage, "description": "Server shutdown not allowed"}},
|
||||||
)
|
)
|
||||||
@ -71,38 +103,11 @@ async def shutdown() -> None:
|
|||||||
os.kill(os.getpid(), signal.SIGTERM)
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/version", response_model=schemas.Version)
|
@router.get(
|
||||||
def get_version() -> dict:
|
"/iou_license",
|
||||||
"""
|
dependencies=[Depends(get_current_active_user)],
|
||||||
Return the server version number.
|
response_model=schemas.IOULicense
|
||||||
"""
|
|
||||||
|
|
||||||
local_server = Config.instance().settings.Server.local
|
|
||||||
return {"version": __version__, "local": local_server}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/version",
|
|
||||||
response_model=schemas.Version,
|
|
||||||
response_model_exclude_defaults=True,
|
|
||||||
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
|
|
||||||
)
|
)
|
||||||
def check_version(version: schemas.Version) -> dict:
|
|
||||||
"""
|
|
||||||
Check if version is the same as the server.
|
|
||||||
|
|
||||||
:param request:
|
|
||||||
:param response:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
print(version.version)
|
|
||||||
if version.version != __version__:
|
|
||||||
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
|
|
||||||
return {"version": __version__}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/iou_license", response_model=schemas.IOULicense)
|
|
||||||
def get_iou_license() -> schemas.IOULicense:
|
def get_iou_license() -> schemas.IOULicense:
|
||||||
"""
|
"""
|
||||||
Return the IOU license settings
|
Return the IOU license settings
|
||||||
@ -111,7 +116,12 @@ def get_iou_license() -> schemas.IOULicense:
|
|||||||
return Controller.instance().iou_license
|
return Controller.instance().iou_license
|
||||||
|
|
||||||
|
|
||||||
@router.put("/iou_license", status_code=status.HTTP_201_CREATED, response_model=schemas.IOULicense)
|
@router.put(
|
||||||
|
"/iou_license",
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
response_model=schemas.IOULicense
|
||||||
|
)
|
||||||
async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULicense:
|
async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULicense:
|
||||||
"""
|
"""
|
||||||
Update the IOU license settings.
|
Update the IOU license settings.
|
||||||
@ -124,7 +134,7 @@ async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULice
|
|||||||
return current_iou_license
|
return current_iou_license
|
||||||
|
|
||||||
|
|
||||||
@router.get("/statistics")
|
@router.get("/statistics", dependencies=[Depends(get_current_active_user)])
|
||||||
async def statistics() -> List[dict]:
|
async def statistics() -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Return server statistics.
|
Return server statistics.
|
||||||
|
@ -24,7 +24,7 @@ from gns3server.services import auth_service
|
|||||||
|
|
||||||
from .database import get_repository
|
from .database import get_repository
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_from_token(
|
async def get_user_from_token(
|
||||||
|
@ -44,10 +44,53 @@ log = logging.getLogger(__name__)
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[schemas.User])
|
@router.post("/login", response_model=schemas.Token)
|
||||||
|
async def login(
|
||||||
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
) -> schemas.Token:
|
||||||
|
"""
|
||||||
|
Default user login method using forms (x-www-form-urlencoded).
|
||||||
|
Example: curl http://host:port/v3/users/login -H "Content-Type: application/x-www-form-urlencoded" -d "username=admin&password=admin"
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication was unsuccessful.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/authenticate", response_model=schemas.Token)
|
||||||
|
async def authenticate(
|
||||||
|
user_credentials: schemas.Credentials,
|
||||||
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
|
) -> schemas.Token:
|
||||||
|
"""
|
||||||
|
Alternative authentication method using json.
|
||||||
|
Example: curl http://host:port/v3/users/authenticate -d '{"username": "admin", "password": "admin"}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = await users_repo.authenticate_user(username=user_credentials.username, password=user_credentials.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication was unsuccessful.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
|
||||||
async def get_users(
|
async def get_users(
|
||||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
current_user: schemas.User = Depends(get_current_active_user)
|
|
||||||
) -> List[schemas.User]:
|
) -> List[schemas.User]:
|
||||||
"""
|
"""
|
||||||
Get all users.
|
Get all users.
|
||||||
@ -56,11 +99,15 @@ async def get_users(
|
|||||||
return await users_repo.get_users()
|
return await users_repo.get_users()
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
|
@router.post(
|
||||||
|
"",
|
||||||
|
response_model=schemas.User,
|
||||||
|
dependencies=[Depends(get_current_active_user)],
|
||||||
|
status_code=status.HTTP_201_CREATED
|
||||||
|
)
|
||||||
async def create_user(
|
async def create_user(
|
||||||
user_create: schemas.UserCreate,
|
user_create: schemas.UserCreate,
|
||||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
current_user: schemas.User = Depends(get_current_active_user)
|
|
||||||
) -> schemas.User:
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Create a new user.
|
Create a new user.
|
||||||
@ -75,11 +122,10 @@ async def create_user(
|
|||||||
return await users_repo.create_user(user_create)
|
return await users_repo.create_user(user_create)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}",response_model=schemas.User)
|
@router.get("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
|
||||||
async def get_user(
|
async def get_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||||
current_user: schemas.User = Depends(get_current_active_user)
|
|
||||||
) -> schemas.User:
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Get an user.
|
Get an user.
|
||||||
@ -91,12 +137,11 @@ async def get_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=schemas.User)
|
@router.put("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
|
||||||
async def update_user(
|
async def update_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_update: schemas.UserUpdate,
|
user_update: schemas.UserUpdate,
|
||||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||||
current_user: schemas.User = Depends(get_current_active_user)
|
|
||||||
) -> schemas.User:
|
) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
Update an user.
|
Update an user.
|
||||||
@ -126,27 +171,6 @@ async def delete_user(
|
|||||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=schemas.Token)
|
|
||||||
async def login(
|
|
||||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
|
||||||
) -> schemas.Token:
|
|
||||||
"""
|
|
||||||
User login.
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Authentication was unsuccessful.",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/users/me/", response_model=schemas.User)
|
@router.get("/users/me/", response_model=schemas.User)
|
||||||
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
|
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
|
||||||
"""
|
"""
|
||||||
|
@ -27,7 +27,7 @@ from .controller.drawings import Drawing
|
|||||||
from .controller.gns3vm import GNS3VM
|
from .controller.gns3vm import GNS3VM
|
||||||
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
||||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
|
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
|
||||||
from .controller.users import UserCreate, UserUpdate, User
|
from .controller.users import UserCreate, UserUpdate, User, Credentials
|
||||||
from .controller.tokens import Token
|
from .controller.tokens import Token
|
||||||
from .controller.snapshots import SnapshotCreate, Snapshot
|
from .controller.snapshots import SnapshotCreate, Snapshot
|
||||||
from .controller.iou_license import IOULicense
|
from .controller.iou_license import IOULicense
|
||||||
|
@ -56,3 +56,9 @@ class User(DateTimeModelMixin, UserBase):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class Credentials(BaseModel):
|
||||||
|
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
@ -214,6 +214,22 @@ class TestUserLogin:
|
|||||||
assert "token_type" in res.json()
|
assert "token_type" in res.json()
|
||||||
assert res.json().get("token_type") == "bearer"
|
assert res.json().get("token_type") == "bearer"
|
||||||
|
|
||||||
|
async def test_user_can_authenticate_using_json(
|
||||||
|
self,
|
||||||
|
app: FastAPI,
|
||||||
|
unauthorized_client: AsyncClient,
|
||||||
|
test_user: User,
|
||||||
|
config: Config
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
credentials = {
|
||||||
|
"username": test_user.username,
|
||||||
|
"password": "user1_password",
|
||||||
|
}
|
||||||
|
res = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials)
|
||||||
|
assert res.status_code == status.HTTP_200_OK
|
||||||
|
assert res.json().get("access_token")
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"username, password, status_code",
|
"username, password, status_code",
|
||||||
(
|
(
|
||||||
|
Loading…
Reference in New Issue
Block a user