1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-13 20:08:55 +00:00

Add HTTP client to reuse the aiohttp session where needed.

Remove unnecessary aiohttp exceptions.
This commit is contained in:
grossmj 2020-10-22 16:19:44 +10:30
parent 36c8920cd1
commit a92c47b310
17 changed files with 221 additions and 151 deletions

View File

@ -41,6 +41,7 @@ from gns3server.controller.controller_error import (
from gns3server.endpoints import controller
from gns3server.endpoints import index
from gns3server.endpoints.compute import compute_api
from gns3server.utils.http_client import HTTPClient
from gns3server.version import __version__
import logging
@ -76,6 +77,7 @@ app.mount("/v2/compute", compute_api)
@app.exception_handler(ControllerError)
async def controller_error_handler(request: Request, exc: ControllerError):
log.error(f"Controller error: {exc}")
return JSONResponse(
status_code=409,
content={"message": str(exc)},
@ -84,6 +86,7 @@ async def controller_error_handler(request: Request, exc: ControllerError):
@app.exception_handler(ControllerTimeoutError)
async def controller_timeout_error_handler(request: Request, exc: ControllerTimeoutError):
log.error(f"Controller timeout error: {exc}")
return JSONResponse(
status_code=408,
content={"message": str(exc)},
@ -92,6 +95,7 @@ async def controller_timeout_error_handler(request: Request, exc: ControllerTime
@app.exception_handler(ControllerUnauthorizedError)
async def controller_unauthorized_error_handler(request: Request, exc: ControllerUnauthorizedError):
log.error(f"Controller unauthorized error: {exc}")
return JSONResponse(
status_code=401,
content={"message": str(exc)},
@ -100,6 +104,7 @@ async def controller_unauthorized_error_handler(request: Request, exc: Controlle
@app.exception_handler(ControllerForbiddenError)
async def controller_forbidden_error_handler(request: Request, exc: ControllerForbiddenError):
log.error(f"Controller forbidden error: {exc}")
return JSONResponse(
status_code=403,
content={"message": str(exc)},
@ -108,6 +113,7 @@ async def controller_forbidden_error_handler(request: Request, exc: ControllerFo
@app.exception_handler(ControllerNotFoundError)
async def controller_not_found_error_handler(request: Request, exc: ControllerNotFoundError):
log.error(f"Controller not found error: {exc}")
return JSONResponse(
status_code=404,
content={"message": str(exc)},
@ -164,13 +170,7 @@ async def startup_event():
@app.on_event("shutdown")
async def shutdown_event():
# close websocket connections
# websocket_connections = set(self._app['websockets'])
# if websocket_connections:
# log.info("Closing {} websocket connections...".format(len(websocket_connections)))
# for ws in websocket_connections:
# await ws.close(code=aiohttp.WSCloseCode.GOING_AWAY, message='Server shutdown')
await HTTPClient.close_session()
await Controller.instance().stop()
for module in MODULES:

View File

@ -60,9 +60,8 @@ class Docker(BaseManager):
if not self._connected:
try:
self._connected = True
connector = self.connector()
version = await self.query("GET", "version")
except (aiohttp.ClientOSError, FileNotFoundError):
except (aiohttp.ClientError, FileNotFoundError):
self._connected = False
raise DockerError("Can't connect to docker daemon")
@ -70,8 +69,8 @@ class Docker(BaseManager):
if docker_version < parse_version(DOCKER_MINIMUM_API_VERSION):
raise DockerError(
"Docker version is {}. GNS3 requires a minimum version of {}".format(
version["Version"], DOCKER_MINIMUM_VERSION))
"Docker version is {}. GNS3 requires a minimum version of {}".format(version["Version"],
DOCKER_MINIMUM_VERSION))
preferred_api_version = parse_version(DOCKER_PREFERRED_API_VERSION)
if docker_version >= preferred_api_version:
@ -84,7 +83,7 @@ class Docker(BaseManager):
raise DockerError("Docker is supported only on Linux")
try:
self._connector = aiohttp.connector.UnixConnector(self._server_url, limit=None)
except (aiohttp.ClientOSError, FileNotFoundError):
except (aiohttp.ClientError, FileNotFoundError):
raise DockerError("Can't connect to docker daemon")
return self._connector
@ -150,7 +149,7 @@ class Docker(BaseManager):
data=data,
headers={"content-type": "application/json", },
timeout=timeout)
except (aiohttp.ClientResponseError, aiohttp.ClientOSError) as e:
except aiohttp.ClientError as e:
raise DockerError("Docker has returned an error: {}".format(str(e)))
except (asyncio.TimeoutError):
raise DockerError("Docker timeout " + method + " " + path)

View File

@ -19,7 +19,6 @@
Dynamips server module.
"""
import aiohttp
import sys
import os
import shutil

View File

@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
from aiohttp.web import HTTPConflict
from fastapi import HTTPException, status
from gns3server.config import Config
import logging
@ -48,12 +48,12 @@ class PortManager:
console_start_port_range = server_config.getint("console_start_port_range", 5000)
console_end_port_range = server_config.getint("console_end_port_range", 10000)
self._console_port_range = (console_start_port_range, console_end_port_range)
log.debug("Console port range is {}-{}".format(console_start_port_range, console_end_port_range))
log.debug(f"Console port range is {console_start_port_range}-{console_end_port_range}")
udp_start_port_range = server_config.getint("udp_start_port_range", 20000)
udp_end_port_range = server_config.getint("udp_end_port_range", 30000)
self._udp_port_range = (udp_start_port_range, udp_end_port_range)
log.debug("UDP port range is {}-{}".format(udp_start_port_range, udp_end_port_range))
log.debug(f"UDP port range is {udp_start_port_range}-{udp_end_port_range}")
@classmethod
def instance(cls):
@ -149,7 +149,8 @@ class PortManager:
"""
if end_port < start_port:
raise HTTPConflict(text="Invalid port range {}-{}".format(start_port, end_port))
raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"Invalid port range {start_port}-{end_port}")
last_exception = None
for port in range(start_port, end_port + 1):
@ -168,10 +169,9 @@ class PortManager:
else:
continue
raise HTTPConflict(text="Could not find a free port between {} and {} on host {}, last exception: {}".format(start_port,
end_port,
host,
last_exception))
raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"Could not find a free port between {start_port} and {end_port} on host {host},"
f" last exception: {last_exception}")
@staticmethod
def _check_port(host, port, socket_type):
@ -212,7 +212,7 @@ class PortManager:
self._used_tcp_ports.add(port)
project.record_tcp_port(port)
log.debug("TCP port {} has been allocated".format(port))
log.debug(f"TCP port {port} has been allocated")
return port
def reserve_tcp_port(self, port, project, port_range_start=None, port_range_end=None):
@ -235,13 +235,14 @@ class PortManager:
if port in self._used_tcp_ports:
old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port)
msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}"
log.debug(msg)
return port
if port < port_range_start or port > port_range_end:
old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} is outside the range {}-{} on host {}. Port has been replaced by {}".format(old_port, port_range_start, port_range_end, self._console_host, port)
msg = f"TCP port {old_port} is outside the range {port_range_start}-{port_range_end} on host " \
f"{self._console_host}. Port has been replaced by {port}"
log.debug(msg)
return port
try:
@ -249,13 +250,13 @@ class PortManager:
except OSError:
old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port)
msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}"
log.debug(msg)
return port
self._used_tcp_ports.add(port)
project.record_tcp_port(port)
log.debug("TCP port {} has been reserved".format(port))
log.debug(f"TCP port {port} has been reserved")
return port
def release_tcp_port(self, port, project):
@ -269,7 +270,7 @@ class PortManager:
if port in self._used_tcp_ports:
self._used_tcp_ports.remove(port)
project.remove_tcp_port(port)
log.debug("TCP port {} has been released".format(port))
log.debug(f"TCP port {port} has been released")
def get_free_udp_port(self, project):
"""
@ -285,7 +286,7 @@ class PortManager:
self._used_udp_ports.add(port)
project.record_udp_port(port)
log.debug("UDP port {} has been allocated".format(port))
log.debug(f"UDP port {port} has been allocated")
return port
def reserve_udp_port(self, port, project):
@ -297,9 +298,12 @@ class PortManager:
"""
if port in self._used_udp_ports:
raise HTTPConflict(text="UDP port {} already in use on host {}".format(port, self._console_host))
raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"UDP port {port} already in use on host {self._console_host}")
if port < self._udp_port_range[0] or port > self._udp_port_range[1]:
raise HTTPConflict(text="UDP port {} is outside the range {}-{}".format(port, self._udp_port_range[0], self._udp_port_range[1]))
raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"UDP port {port} is outside the range "
f"{self._udp_port_range[0]}-{self._udp_port_range[1]}")
self._used_udp_ports.add(port)
project.record_udp_port(port)
log.debug("UDP port {} has been reserved".format(port))
@ -315,4 +319,4 @@ class PortManager:
if port in self._used_udp_ports:
self._used_udp_ports.remove(port)
project.remove_udp_port(port)
log.debug("UDP port {} has been released".format(port))
log.debug(f"UDP port {port} has been released")

View File

@ -19,12 +19,12 @@ import os
import json
import uuid
import asyncio
import aiohttp
from .appliance import Appliance
from ..config import Config
from ..utils.asyncio import locking
from ..utils.get_resource import get_resource
from ..utils.http_client import HTTPClient
from .controller_error import ControllerError
import logging
@ -142,20 +142,19 @@ class ApplianceManager:
"""
symbol_url = "https://raw.githubusercontent.com/GNS3/gns3-registry/master/symbols/{}".format(symbol)
async with aiohttp.ClientSession() as session:
async with session.get(symbol_url) as response:
if response.status != 200:
log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status))
else:
try:
symbol_data = await response.read()
log.info("Saving {} symbol to {}".format(symbol, destination_path))
with open(destination_path, 'wb') as f:
f.write(symbol_data)
except asyncio.TimeoutError:
log.warning("Timeout while downloading '{}'".format(symbol_url))
except OSError as e:
log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e))
async with HTTPClient.get(symbol_url) as response:
if response.status != 200:
log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status))
else:
try:
symbol_data = await response.read()
log.info("Saving {} symbol to {}".format(symbol, destination_path))
with open(destination_path, 'wb') as f:
f.write(symbol_data)
except asyncio.TimeoutError:
log.warning("Timeout while downloading '{}'".format(symbol_url))
except OSError as e:
log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e))
@locking
async def download_appliances(self):
@ -168,40 +167,40 @@ class ApplianceManager:
if self._appliances_etag:
log.info("Checking if appliances are up-to-date (ETag {})".format(self._appliances_etag))
headers["If-None-Match"] = self._appliances_etag
async with aiohttp.ClientSession() as session:
async with session.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response:
if response.status == 304:
log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag))
return
elif response.status != 200:
raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status))
etag = response.headers.get("ETag")
if etag:
self._appliances_etag = etag
from . import Controller
Controller.instance().save()
json_data = await response.json()
appliances_dir = get_resource('appliances')
for appliance in json_data:
if appliance["type"] == "file":
appliance_name = appliance["name"]
log.info("Download appliance file from '{}'".format(appliance["download_url"]))
async with session.get(appliance["download_url"]) as response:
if response.status != 200:
log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status))
continue
try:
appliance_data = await response.read()
except asyncio.TimeoutError:
log.warning("Timeout while downloading '{}'".format(appliance["download_url"]))
continue
path = os.path.join(appliances_dir, appliance_name)
try:
log.info("Saving {} file to {}".format(appliance_name, path))
with open(path, 'wb') as f:
f.write(appliance_data)
except OSError as e:
raise ControllerError("Could not write appliance file '{}': {}".format(path, e))
async with HTTPClient.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response:
if response.status == 304:
log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag))
return
elif response.status != 200:
raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status))
etag = response.headers.get("ETag")
if etag:
self._appliances_etag = etag
from . import Controller
Controller.instance().save()
json_data = await response.json()
appliances_dir = get_resource('appliances')
for appliance in json_data:
if appliance["type"] == "file":
appliance_name = appliance["name"]
log.info("Download appliance file from '{}'".format(appliance["download_url"]))
async with HTTPClient.get(appliance["download_url"]) as response:
if response.status != 200:
log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status))
continue
try:
appliance_data = await response.read()
except asyncio.TimeoutError:
log.warning("Timeout while downloading '{}'".format(appliance["download_url"]))
continue
path = os.path.join(appliances_dir, appliance_name)
try:
log.info("Saving {} file to {}".format(appliance_name, path))
with open(path, 'wb') as f:
f.write(appliance_data)
except OSError as e:
raise ControllerError("Could not write appliance file '{}': {}".format(path, e))
except ValueError as e:
raise ControllerError("Could not read appliances information from GitHub: {}".format(e))

View File

@ -63,7 +63,9 @@ class Compute:
A GNS3 compute.
"""
def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None, password=None, name=None, console_host=None):
def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None,
password=None, name=None, console_host=None):
self._http_session = None
assert controller is not None
log.info("Create compute %s", compute_id)
@ -103,14 +105,10 @@ class Compute:
def _session(self):
if self._http_session is None or self._http_session.closed is True:
self._http_session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True))
connector = aiohttp.TCPConnector(force_close=True)
self._http_session = aiohttp.ClientSession(connector=connector)
return self._http_session
#def __del__(self):
#
# if self._http_session:
# self._http_session.close()
def _set_auth(self, user, password):
"""
Set authentication parameters
@ -466,7 +464,7 @@ class Compute:
elif response.type == aiohttp.WSMsgType.CLOSED:
pass
break
except aiohttp.client_exceptions.ClientResponseError as e:
except aiohttp.ClientError as e:
log.error("Client response error received on compute '{}' WebSocket '{}': {}".format(self._id, ws_url,e))
finally:
self._connected = False
@ -503,8 +501,7 @@ class Compute:
async def _run_http_query(self, method, path, data=None, timeout=20, raw=False):
with async_timeout.timeout(timeout):
url = self._getUrl(path)
headers = {}
headers['content-type'] = 'application/json'
headers = {'content-type': 'application/json'}
chunked = None
if data == {}:
data = None
@ -579,7 +576,7 @@ class Compute:
return response
async def get(self, path, **kwargs):
return (await self.http_query("GET", path, **kwargs))
return await self.http_query("GET", path, **kwargs)
async def post(self, path, data={}, **kwargs):
response = await self.http_query("POST", path, data, **kwargs)
@ -600,15 +597,13 @@ class Compute:
action = "/{}/{}".format(type, path)
res = await self.http_query(method, action, data=data, timeout=None)
except aiohttp.ServerDisconnectedError:
log.error("Connection lost to %s during %s %s", self._id, method, action)
raise aiohttp.web.HTTPGatewayTimeout()
raise ControllerError(f"Connection lost to {self._id} during {method} {action}")
return res.json
async def images(self, type):
"""
Return the list of images available for this type on the compute node.
"""
images = []
res = await self.http_query("GET", "/{}/images".format(type), timeout=None)
images = res.json
@ -641,11 +636,11 @@ class Compute:
:returns: Tuple (ip_for_this_compute, ip_for_other_compute)
"""
if other_compute == self:
return (self.host_ip, self.host_ip)
return self.host_ip, self.host_ip
# Perhaps the user has correct network gateway, we trust him
if (self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1')):
return (self.host_ip, other_compute.host_ip)
if self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1'):
return self.host_ip, other_compute.host_ip
this_compute_interfaces = await self.interfaces()
other_compute_interfaces = await other_compute.interfaces()
@ -675,6 +670,6 @@ class Compute:
other_network = ipaddress.ip_network("{}/{}".format(other_interface["ip_address"], other_interface["netmask"]), strict=False)
if this_network.overlaps(other_network):
return (this_interface["ip_address"], other_interface["ip_address"])
return this_interface["ip_address"], other_interface["ip_address"]
raise ValueError("No common subnet for compute {} and {}".format(self.name, other_compute.name))

View File

@ -24,6 +24,7 @@ import socket
from .base_gns3_vm import BaseGNS3VM
from .gns3_vm_error import GNS3VMError
from gns3server.utils import parse_version
from gns3server.utils.http_client import HTTPClient
from gns3server.utils.asyncio import wait_run_in_executor
from ...compute.virtualbox import (
@ -305,24 +306,24 @@ class VirtualBoxGNS3VM(BaseGNS3VM):
second to a GNS3 endpoint in order to get the list of the interfaces and
their IP and after that match it with VirtualBox host only.
"""
remaining_try = 300
while remaining_try > 0:
async with aiohttp.ClientSession() as session:
try:
async with session.get('http://127.0.0.1:{}/v2/compute/network/interfaces'.format(api_port)) as resp:
if resp.status < 300:
try:
json_data = await resp.json()
if json_data:
for interface in json_data:
if "name" in interface and interface["name"] == "eth{}".format(
hostonly_interface_number - 1):
if "ip_address" in interface and len(interface["ip_address"]) > 0:
return interface["ip_address"]
except ValueError:
pass
except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError):
pass
try:
async with HTTPClient.get(f"http://127.0.0.1:{api_port}/v2/compute/network/interfaces") as resp:
if resp.status < 300:
try:
json_data = await resp.json()
if json_data:
for interface in json_data:
if "name" in interface and interface["name"] == "eth{}".format(
hostonly_interface_number - 1):
if "ip_address" in interface and len(interface["ip_address"]) > 0:
return interface["ip_address"]
except ValueError:
pass
except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError):
pass
remaining_try -= 1
await asyncio.sleep(1)
raise GNS3VMError("Could not find guest IP address for {}".format(self.vmname))

View File

@ -19,8 +19,8 @@
API endpoints for links.
"""
import aiohttp
import multidict
import aiohttp
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import StreamingResponse
@ -31,9 +31,13 @@ from uuid import UUID
from gns3server.controller import Controller
from gns3server.controller.controller_error import ControllerError
from gns3server.controller.link import Link
from gns3server.utils.http_client import HTTPClient
from gns3server.endpoints.schemas.common import ErrorMessage
from gns3server.endpoints import schemas
import logging
log = logging.getLogger(__name__)
router = APIRouter()
responses = {
@ -201,12 +205,13 @@ async def pcap(request: Request, link: Link = Depends(dep_link)):
async def compute_pcap_stream():
connector = aiohttp.TCPConnector(limit=None, force_close=True)
async with aiohttp.ClientSession(connector=connector, headers=headers) as session:
async with session.request(request.method, pcap_streaming_url, timeout=None, data=body) as compute_response:
async for data in compute_response.content.iter_any():
try:
async with HTTPClient.request(request.method, pcap_streaming_url, timeout=None, data=body) as response:
async for data in response.content.iter_any():
if not data:
break
yield data
except aiohttp.ClientError as e:
raise ControllerError(f"Client error received when receiving pcap stream from compute: {e}")
return StreamingResponse(compute_pcap_stream(), media_type="application/vnd.tcpdump.pcap")

View File

@ -32,6 +32,7 @@ from gns3server.controller import Controller
from gns3server.controller.node import Node
from gns3server.controller.project import Project
from gns3server.utils import force_unix_path
from gns3server.utils.http_client import HTTPClient
from gns3server.controller.controller_error import ControllerForbiddenError
from gns3server.endpoints.schemas.common import ErrorMessage
from gns3server.endpoints import schemas
@ -400,18 +401,17 @@ async def ws_console(websocket: WebSocket, node: Node = Depends(dep_node)):
try:
# receive WebSocket data from compute console WebSocket and forward to client.
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True)) as session:
async with session.ws_connect(ws_console_compute_url) as ws_console_compute:
asyncio.ensure_future(ws_receive(ws_console_compute))
async for msg in ws_console_compute:
if msg.type == aiohttp.WSMsgType.TEXT:
await websocket.send_text(msg.data)
elif msg.type == aiohttp.WSMsgType.BINARY:
await websocket.send_bytes(msg.data)
elif msg.type == aiohttp.WSMsgType.ERROR:
break
except aiohttp.client_exceptions.ClientResponseError as e:
log.error(f"Client response error received when forwarding to compute console WebSocket: {e}")
async with HTTPClient.get_client().ws_connect(ws_console_compute_url) as ws_console_compute:
asyncio.ensure_future(ws_receive(ws_console_compute))
async for msg in ws_console_compute:
if msg.type == aiohttp.WSMsgType.TEXT:
await websocket.send_text(msg.data)
elif msg.type == aiohttp.WSMsgType.BINARY:
await websocket.send_bytes(msg.data)
elif msg.type == aiohttp.WSMsgType.ERROR:
break
except aiohttp.ClientError as e:
log.error(f"Client error received when forwarding to compute console WebSocket: {e}")
@router.post("/console/reset",

View File

@ -29,7 +29,7 @@ import time
import logging
log = logging.getLogger()
from fastapi import APIRouter, Depends, Request, Body, Query, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Request, Body, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, FileResponse
from websockets.exceptions import ConnectionClosed, WebSocketException

View File

@ -0,0 +1,69 @@
#!/usr/bin/env python
#
# Copyright (C) 2020 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import socket
import logging
log = logging.getLogger(__name__)
class HTTPClient:
"""
HTTP client for request to computes and external services.
"""
_aiohttp_client: aiohttp.ClientSession = None
@classmethod
def get_client(cls) -> aiohttp.ClientSession:
if cls._aiohttp_client is None:
cls._aiohttp_client = aiohttp.ClientSession(connector=aiohttp.TCPConnector(family=socket.AF_INET))
return cls._aiohttp_client
@classmethod
async def close_session(cls):
if cls._aiohttp_client:
await cls._aiohttp_client.close()
cls._aiohttp_client = None
@classmethod
def request(cls, method: str, url: str, user: str = None, password: str = None, **kwargs):
client = cls.get_client()
basic_auth = None
if user:
if not password:
password = ""
try:
basic_auth = aiohttp.BasicAuth(user, password, "utf-8")
except ValueError as e:
log.error(f"Basic authentication set-up error: {e}")
return client.request(method, url, auth=basic_auth, **kwargs)
@classmethod
def get(cls, path, **kwargs):
return cls.request("GET", path, **kwargs)
@classmethod
def post(cls, path, **kwargs):
return cls.request("POST", path, **kwargs)
@classmethod
def put(cls, path, **kwargs):
return cls.request("PUT", path, **kwargs)

View File

@ -18,12 +18,12 @@
import os
import sys
import aiohttp
import socket
import struct
import psutil
from .windows_service import check_windows_service_is_running
from gns3server.compute.compute_error import ComputeError
from gns3server.config import Config
if psutil.version_info < (3, 0, 0):
@ -162,7 +162,7 @@ def is_interface_up(interface):
return True
return False
except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Exception when checking if {} is up: {}".format(interface, e))
raise ComputeError(f"Exception when checking if {interface} is up: {e}")
else:
# TODO: Windows & OSX support
return True
@ -221,13 +221,13 @@ def interfaces():
results = get_windows_interfaces()
except ImportError:
message = "pywin32 module is not installed, please install it on the server to get the available interface names"
raise aiohttp.web.HTTPInternalServerError(text=message)
raise ComputeError(message)
except Exception as e:
log.error("uncaught exception {type}".format(type=type(e)), exc_info=1)
raise aiohttp.web.HTTPInternalServerError(text="uncaught exception: {}".format(e))
raise ComputeError(f"uncaught exception: {e}")
if service_installed is False:
raise aiohttp.web.HTTPInternalServerError(text="The Winpcap or Npcap is not installed or running")
raise ComputeError("The Winpcap or Npcap is not installed or running")
# This interface have special behavior
for result in results:

View File

@ -16,8 +16,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import aiohttp
from fastapi import HTTPException, status
from ..config import Config
@ -33,7 +33,7 @@ def get_default_project_directory():
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e))
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"Could not create project directory: {e}")
return path
@ -52,4 +52,4 @@ def check_path_allowed(path):
return
if "local" in config and config.getboolean("local") is False:
raise aiohttp.web.HTTPForbidden(text="The path is not allowed")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The path is not allowed")

View File

@ -19,7 +19,7 @@
Check for Windows service.
"""
import aiohttp
from gns3server.compute.compute_error import ComputeError
def check_windows_service_is_running(service_name):
@ -35,5 +35,5 @@ def check_windows_service_is_running(service_name):
if e.winerror == 1060:
return False
else:
raise aiohttp.web.HTTPInternalServerError(text="Could not check if the {} service is running: {}".format(service_name, e.strerror))
raise ComputeError(f"Could not check if the {service_name} service is running: {e.strerror}")
return True

View File

@ -15,9 +15,10 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import pytest
import uuid
from fastapi import HTTPException
from unittest.mock import patch
from gns3server.compute.port_manager import PortManager
@ -94,7 +95,7 @@ def test_reserve_udp_port():
pm = PortManager()
project = Project(project_id=str(uuid.uuid4()))
pm.reserve_udp_port(20000, project)
with pytest.raises(aiohttp.web.HTTPConflict):
with pytest.raises(HTTPException):
pm.reserve_udp_port(20000, project)
@ -102,7 +103,7 @@ def test_reserve_udp_port_outside_range():
pm = PortManager()
project = Project(project_id=str(uuid.uuid4()))
with pytest.raises(aiohttp.web.HTTPConflict):
with pytest.raises(HTTPException):
pm.reserve_udp_port(80, project)
@ -123,7 +124,7 @@ def test_find_unused_port():
def test_find_unused_port_invalid_range():
with pytest.raises(aiohttp.web.HTTPConflict):
with pytest.raises(HTTPException):
p = PortManager().find_unused_port(10000, 1000)

View File

@ -74,8 +74,6 @@ async def test_compute_get(controller_api):
response = await controller_api.get("/computes/my_compute_id")
assert response.status_code == 200
print(response.json)
#assert response.json["protocol"] == "http"
@pytest.mark.asyncio

View File

@ -17,8 +17,8 @@
import os
import pytest
import aiohttp
from fastapi import HTTPException
from gns3server.utils.path import check_path_allowed, get_default_project_directory
@ -27,7 +27,7 @@ def test_check_path_allowed(config, tmpdir):
config.set("Server", "local", False)
config.set("Server", "projects_path", str(tmpdir))
with pytest.raises(aiohttp.web.HTTPForbidden):
with pytest.raises(HTTPException):
check_path_allowed("/private")
config.set("Server", "local", True)