From a92c47b310768fb4c417fb321dfc679b93735fc6 Mon Sep 17 00:00:00 2001 From: grossmj Date: Thu, 22 Oct 2020 16:19:44 +1030 Subject: [PATCH] Add HTTP client to reuse the aiohttp session where needed. Remove unnecessary aiohttp exceptions. --- gns3server/app.py | 14 +-- gns3server/compute/docker/__init__.py | 11 +-- gns3server/compute/dynamips/__init__.py | 1 - gns3server/compute/port_manager.py | 40 ++++---- gns3server/controller/appliance_manager.py | 97 +++++++++---------- gns3server/controller/compute.py | 31 +++--- .../controller/gns3vm/virtualbox_gns3_vm.py | 33 ++++--- gns3server/endpoints/controller/links.py | 15 ++- gns3server/endpoints/controller/nodes.py | 24 ++--- gns3server/endpoints/controller/projects.py | 2 +- gns3server/utils/http_client.py | 69 +++++++++++++ gns3server/utils/interfaces.py | 10 +- gns3server/utils/path.py | 6 +- gns3server/utils/windows_service.py | 4 +- tests/compute/test_port_manager.py | 9 +- tests/endpoints/controller/test_computes.py | 2 - tests/utils/test_path.py | 4 +- 17 files changed, 221 insertions(+), 151 deletions(-) create mode 100644 gns3server/utils/http_client.py diff --git a/gns3server/app.py b/gns3server/app.py index 10e50d79..8b4da4a6 100644 --- a/gns3server/app.py +++ b/gns3server/app.py @@ -41,6 +41,7 @@ from gns3server.controller.controller_error import ( from gns3server.endpoints import controller from gns3server.endpoints import index from gns3server.endpoints.compute import compute_api +from gns3server.utils.http_client import HTTPClient from gns3server.version import __version__ import logging @@ -76,6 +77,7 @@ app.mount("/v2/compute", compute_api) @app.exception_handler(ControllerError) async def controller_error_handler(request: Request, exc: ControllerError): + log.error(f"Controller error: {exc}") return JSONResponse( status_code=409, content={"message": str(exc)}, @@ -84,6 +86,7 @@ async def controller_error_handler(request: Request, exc: ControllerError): @app.exception_handler(ControllerTimeoutError) async def controller_timeout_error_handler(request: Request, exc: ControllerTimeoutError): + log.error(f"Controller timeout error: {exc}") return JSONResponse( status_code=408, content={"message": str(exc)}, @@ -92,6 +95,7 @@ async def controller_timeout_error_handler(request: Request, exc: ControllerTime @app.exception_handler(ControllerUnauthorizedError) async def controller_unauthorized_error_handler(request: Request, exc: ControllerUnauthorizedError): + log.error(f"Controller unauthorized error: {exc}") return JSONResponse( status_code=401, content={"message": str(exc)}, @@ -100,6 +104,7 @@ async def controller_unauthorized_error_handler(request: Request, exc: Controlle @app.exception_handler(ControllerForbiddenError) async def controller_forbidden_error_handler(request: Request, exc: ControllerForbiddenError): + log.error(f"Controller forbidden error: {exc}") return JSONResponse( status_code=403, content={"message": str(exc)}, @@ -108,6 +113,7 @@ async def controller_forbidden_error_handler(request: Request, exc: ControllerFo @app.exception_handler(ControllerNotFoundError) async def controller_not_found_error_handler(request: Request, exc: ControllerNotFoundError): + log.error(f"Controller not found error: {exc}") return JSONResponse( status_code=404, content={"message": str(exc)}, @@ -164,13 +170,7 @@ async def startup_event(): @app.on_event("shutdown") async def shutdown_event(): - # close websocket connections - # websocket_connections = set(self._app['websockets']) - # if websocket_connections: - # log.info("Closing {} websocket connections...".format(len(websocket_connections))) - # for ws in websocket_connections: - # await ws.close(code=aiohttp.WSCloseCode.GOING_AWAY, message='Server shutdown') - + await HTTPClient.close_session() await Controller.instance().stop() for module in MODULES: diff --git a/gns3server/compute/docker/__init__.py b/gns3server/compute/docker/__init__.py index 1a13d1af..126311db 100644 --- a/gns3server/compute/docker/__init__.py +++ b/gns3server/compute/docker/__init__.py @@ -60,9 +60,8 @@ class Docker(BaseManager): if not self._connected: try: self._connected = True - connector = self.connector() version = await self.query("GET", "version") - except (aiohttp.ClientOSError, FileNotFoundError): + except (aiohttp.ClientError, FileNotFoundError): self._connected = False raise DockerError("Can't connect to docker daemon") @@ -70,8 +69,8 @@ class Docker(BaseManager): if docker_version < parse_version(DOCKER_MINIMUM_API_VERSION): raise DockerError( - "Docker version is {}. GNS3 requires a minimum version of {}".format( - version["Version"], DOCKER_MINIMUM_VERSION)) + "Docker version is {}. GNS3 requires a minimum version of {}".format(version["Version"], + DOCKER_MINIMUM_VERSION)) preferred_api_version = parse_version(DOCKER_PREFERRED_API_VERSION) if docker_version >= preferred_api_version: @@ -84,7 +83,7 @@ class Docker(BaseManager): raise DockerError("Docker is supported only on Linux") try: self._connector = aiohttp.connector.UnixConnector(self._server_url, limit=None) - except (aiohttp.ClientOSError, FileNotFoundError): + except (aiohttp.ClientError, FileNotFoundError): raise DockerError("Can't connect to docker daemon") return self._connector @@ -150,7 +149,7 @@ class Docker(BaseManager): data=data, headers={"content-type": "application/json", }, timeout=timeout) - except (aiohttp.ClientResponseError, aiohttp.ClientOSError) as e: + except aiohttp.ClientError as e: raise DockerError("Docker has returned an error: {}".format(str(e))) except (asyncio.TimeoutError): raise DockerError("Docker timeout " + method + " " + path) diff --git a/gns3server/compute/dynamips/__init__.py b/gns3server/compute/dynamips/__init__.py index 499e3653..cf77daf8 100644 --- a/gns3server/compute/dynamips/__init__.py +++ b/gns3server/compute/dynamips/__init__.py @@ -19,7 +19,6 @@ Dynamips server module. """ -import aiohttp import sys import os import shutil diff --git a/gns3server/compute/port_manager.py b/gns3server/compute/port_manager.py index bc91a435..ac0f4e73 100644 --- a/gns3server/compute/port_manager.py +++ b/gns3server/compute/port_manager.py @@ -16,7 +16,7 @@ # along with this program. If not, see . import socket -from aiohttp.web import HTTPConflict +from fastapi import HTTPException, status from gns3server.config import Config import logging @@ -48,12 +48,12 @@ class PortManager: console_start_port_range = server_config.getint("console_start_port_range", 5000) console_end_port_range = server_config.getint("console_end_port_range", 10000) self._console_port_range = (console_start_port_range, console_end_port_range) - log.debug("Console port range is {}-{}".format(console_start_port_range, console_end_port_range)) + log.debug(f"Console port range is {console_start_port_range}-{console_end_port_range}") udp_start_port_range = server_config.getint("udp_start_port_range", 20000) udp_end_port_range = server_config.getint("udp_end_port_range", 30000) self._udp_port_range = (udp_start_port_range, udp_end_port_range) - log.debug("UDP port range is {}-{}".format(udp_start_port_range, udp_end_port_range)) + log.debug(f"UDP port range is {udp_start_port_range}-{udp_end_port_range}") @classmethod def instance(cls): @@ -149,7 +149,8 @@ class PortManager: """ if end_port < start_port: - raise HTTPConflict(text="Invalid port range {}-{}".format(start_port, end_port)) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, + detail=f"Invalid port range {start_port}-{end_port}") last_exception = None for port in range(start_port, end_port + 1): @@ -168,10 +169,9 @@ class PortManager: else: continue - raise HTTPConflict(text="Could not find a free port between {} and {} on host {}, last exception: {}".format(start_port, - end_port, - host, - last_exception)) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, + detail=f"Could not find a free port between {start_port} and {end_port} on host {host}," + f" last exception: {last_exception}") @staticmethod def _check_port(host, port, socket_type): @@ -212,7 +212,7 @@ class PortManager: self._used_tcp_ports.add(port) project.record_tcp_port(port) - log.debug("TCP port {} has been allocated".format(port)) + log.debug(f"TCP port {port} has been allocated") return port def reserve_tcp_port(self, port, project, port_range_start=None, port_range_end=None): @@ -235,13 +235,14 @@ class PortManager: if port in self._used_tcp_ports: old_port = port port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) - msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port) + msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}" log.debug(msg) return port if port < port_range_start or port > port_range_end: old_port = port port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) - msg = "TCP port {} is outside the range {}-{} on host {}. Port has been replaced by {}".format(old_port, port_range_start, port_range_end, self._console_host, port) + msg = f"TCP port {old_port} is outside the range {port_range_start}-{port_range_end} on host " \ + f"{self._console_host}. Port has been replaced by {port}" log.debug(msg) return port try: @@ -249,13 +250,13 @@ class PortManager: except OSError: old_port = port port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) - msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port) + msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}" log.debug(msg) return port self._used_tcp_ports.add(port) project.record_tcp_port(port) - log.debug("TCP port {} has been reserved".format(port)) + log.debug(f"TCP port {port} has been reserved") return port def release_tcp_port(self, port, project): @@ -269,7 +270,7 @@ class PortManager: if port in self._used_tcp_ports: self._used_tcp_ports.remove(port) project.remove_tcp_port(port) - log.debug("TCP port {} has been released".format(port)) + log.debug(f"TCP port {port} has been released") def get_free_udp_port(self, project): """ @@ -285,7 +286,7 @@ class PortManager: self._used_udp_ports.add(port) project.record_udp_port(port) - log.debug("UDP port {} has been allocated".format(port)) + log.debug(f"UDP port {port} has been allocated") return port def reserve_udp_port(self, port, project): @@ -297,9 +298,12 @@ class PortManager: """ if port in self._used_udp_ports: - raise HTTPConflict(text="UDP port {} already in use on host {}".format(port, self._console_host)) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, + detail=f"UDP port {port} already in use on host {self._console_host}") if port < self._udp_port_range[0] or port > self._udp_port_range[1]: - raise HTTPConflict(text="UDP port {} is outside the range {}-{}".format(port, self._udp_port_range[0], self._udp_port_range[1])) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, + detail=f"UDP port {port} is outside the range " + f"{self._udp_port_range[0]}-{self._udp_port_range[1]}") self._used_udp_ports.add(port) project.record_udp_port(port) log.debug("UDP port {} has been reserved".format(port)) @@ -315,4 +319,4 @@ class PortManager: if port in self._used_udp_ports: self._used_udp_ports.remove(port) project.remove_udp_port(port) - log.debug("UDP port {} has been released".format(port)) + log.debug(f"UDP port {port} has been released") diff --git a/gns3server/controller/appliance_manager.py b/gns3server/controller/appliance_manager.py index f96373e8..da2dd86e 100644 --- a/gns3server/controller/appliance_manager.py +++ b/gns3server/controller/appliance_manager.py @@ -19,12 +19,12 @@ import os import json import uuid import asyncio -import aiohttp from .appliance import Appliance from ..config import Config from ..utils.asyncio import locking from ..utils.get_resource import get_resource +from ..utils.http_client import HTTPClient from .controller_error import ControllerError import logging @@ -142,20 +142,19 @@ class ApplianceManager: """ symbol_url = "https://raw.githubusercontent.com/GNS3/gns3-registry/master/symbols/{}".format(symbol) - async with aiohttp.ClientSession() as session: - async with session.get(symbol_url) as response: - if response.status != 200: - log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status)) - else: - try: - symbol_data = await response.read() - log.info("Saving {} symbol to {}".format(symbol, destination_path)) - with open(destination_path, 'wb') as f: - f.write(symbol_data) - except asyncio.TimeoutError: - log.warning("Timeout while downloading '{}'".format(symbol_url)) - except OSError as e: - log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e)) + async with HTTPClient.get(symbol_url) as response: + if response.status != 200: + log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status)) + else: + try: + symbol_data = await response.read() + log.info("Saving {} symbol to {}".format(symbol, destination_path)) + with open(destination_path, 'wb') as f: + f.write(symbol_data) + except asyncio.TimeoutError: + log.warning("Timeout while downloading '{}'".format(symbol_url)) + except OSError as e: + log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e)) @locking async def download_appliances(self): @@ -168,40 +167,40 @@ class ApplianceManager: if self._appliances_etag: log.info("Checking if appliances are up-to-date (ETag {})".format(self._appliances_etag)) headers["If-None-Match"] = self._appliances_etag - async with aiohttp.ClientSession() as session: - async with session.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response: - if response.status == 304: - log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag)) - return - elif response.status != 200: - raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status)) - etag = response.headers.get("ETag") - if etag: - self._appliances_etag = etag - from . import Controller - Controller.instance().save() - json_data = await response.json() - appliances_dir = get_resource('appliances') - for appliance in json_data: - if appliance["type"] == "file": - appliance_name = appliance["name"] - log.info("Download appliance file from '{}'".format(appliance["download_url"])) - async with session.get(appliance["download_url"]) as response: - if response.status != 200: - log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status)) - continue - try: - appliance_data = await response.read() - except asyncio.TimeoutError: - log.warning("Timeout while downloading '{}'".format(appliance["download_url"])) - continue - path = os.path.join(appliances_dir, appliance_name) - try: - log.info("Saving {} file to {}".format(appliance_name, path)) - with open(path, 'wb') as f: - f.write(appliance_data) - except OSError as e: - raise ControllerError("Could not write appliance file '{}': {}".format(path, e)) + + async with HTTPClient.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response: + if response.status == 304: + log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag)) + return + elif response.status != 200: + raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status)) + etag = response.headers.get("ETag") + if etag: + self._appliances_etag = etag + from . import Controller + Controller.instance().save() + json_data = await response.json() + appliances_dir = get_resource('appliances') + for appliance in json_data: + if appliance["type"] == "file": + appliance_name = appliance["name"] + log.info("Download appliance file from '{}'".format(appliance["download_url"])) + async with HTTPClient.get(appliance["download_url"]) as response: + if response.status != 200: + log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status)) + continue + try: + appliance_data = await response.read() + except asyncio.TimeoutError: + log.warning("Timeout while downloading '{}'".format(appliance["download_url"])) + continue + path = os.path.join(appliances_dir, appliance_name) + try: + log.info("Saving {} file to {}".format(appliance_name, path)) + with open(path, 'wb') as f: + f.write(appliance_data) + except OSError as e: + raise ControllerError("Could not write appliance file '{}': {}".format(path, e)) except ValueError as e: raise ControllerError("Could not read appliances information from GitHub: {}".format(e)) diff --git a/gns3server/controller/compute.py b/gns3server/controller/compute.py index 770acce0..892146d9 100644 --- a/gns3server/controller/compute.py +++ b/gns3server/controller/compute.py @@ -63,7 +63,9 @@ class Compute: A GNS3 compute. """ - def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None, password=None, name=None, console_host=None): + def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None, + password=None, name=None, console_host=None): + self._http_session = None assert controller is not None log.info("Create compute %s", compute_id) @@ -103,14 +105,10 @@ class Compute: def _session(self): if self._http_session is None or self._http_session.closed is True: - self._http_session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True)) + connector = aiohttp.TCPConnector(force_close=True) + self._http_session = aiohttp.ClientSession(connector=connector) return self._http_session - #def __del__(self): - # - # if self._http_session: - # self._http_session.close() - def _set_auth(self, user, password): """ Set authentication parameters @@ -466,7 +464,7 @@ class Compute: elif response.type == aiohttp.WSMsgType.CLOSED: pass break - except aiohttp.client_exceptions.ClientResponseError as e: + except aiohttp.ClientError as e: log.error("Client response error received on compute '{}' WebSocket '{}': {}".format(self._id, ws_url,e)) finally: self._connected = False @@ -503,8 +501,7 @@ class Compute: async def _run_http_query(self, method, path, data=None, timeout=20, raw=False): with async_timeout.timeout(timeout): url = self._getUrl(path) - headers = {} - headers['content-type'] = 'application/json' + headers = {'content-type': 'application/json'} chunked = None if data == {}: data = None @@ -579,7 +576,7 @@ class Compute: return response async def get(self, path, **kwargs): - return (await self.http_query("GET", path, **kwargs)) + return await self.http_query("GET", path, **kwargs) async def post(self, path, data={}, **kwargs): response = await self.http_query("POST", path, data, **kwargs) @@ -600,15 +597,13 @@ class Compute: action = "/{}/{}".format(type, path) res = await self.http_query(method, action, data=data, timeout=None) except aiohttp.ServerDisconnectedError: - log.error("Connection lost to %s during %s %s", self._id, method, action) - raise aiohttp.web.HTTPGatewayTimeout() + raise ControllerError(f"Connection lost to {self._id} during {method} {action}") return res.json async def images(self, type): """ Return the list of images available for this type on the compute node. """ - images = [] res = await self.http_query("GET", "/{}/images".format(type), timeout=None) images = res.json @@ -641,11 +636,11 @@ class Compute: :returns: Tuple (ip_for_this_compute, ip_for_other_compute) """ if other_compute == self: - return (self.host_ip, self.host_ip) + return self.host_ip, self.host_ip # Perhaps the user has correct network gateway, we trust him - if (self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1')): - return (self.host_ip, other_compute.host_ip) + if self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1'): + return self.host_ip, other_compute.host_ip this_compute_interfaces = await self.interfaces() other_compute_interfaces = await other_compute.interfaces() @@ -675,6 +670,6 @@ class Compute: other_network = ipaddress.ip_network("{}/{}".format(other_interface["ip_address"], other_interface["netmask"]), strict=False) if this_network.overlaps(other_network): - return (this_interface["ip_address"], other_interface["ip_address"]) + return this_interface["ip_address"], other_interface["ip_address"] raise ValueError("No common subnet for compute {} and {}".format(self.name, other_compute.name)) diff --git a/gns3server/controller/gns3vm/virtualbox_gns3_vm.py b/gns3server/controller/gns3vm/virtualbox_gns3_vm.py index 874cca36..fb9fbc99 100644 --- a/gns3server/controller/gns3vm/virtualbox_gns3_vm.py +++ b/gns3server/controller/gns3vm/virtualbox_gns3_vm.py @@ -24,6 +24,7 @@ import socket from .base_gns3_vm import BaseGNS3VM from .gns3_vm_error import GNS3VMError from gns3server.utils import parse_version +from gns3server.utils.http_client import HTTPClient from gns3server.utils.asyncio import wait_run_in_executor from ...compute.virtualbox import ( @@ -305,24 +306,24 @@ class VirtualBoxGNS3VM(BaseGNS3VM): second to a GNS3 endpoint in order to get the list of the interfaces and their IP and after that match it with VirtualBox host only. """ + remaining_try = 300 while remaining_try > 0: - async with aiohttp.ClientSession() as session: - try: - async with session.get('http://127.0.0.1:{}/v2/compute/network/interfaces'.format(api_port)) as resp: - if resp.status < 300: - try: - json_data = await resp.json() - if json_data: - for interface in json_data: - if "name" in interface and interface["name"] == "eth{}".format( - hostonly_interface_number - 1): - if "ip_address" in interface and len(interface["ip_address"]) > 0: - return interface["ip_address"] - except ValueError: - pass - except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError): - pass + try: + async with HTTPClient.get(f"http://127.0.0.1:{api_port}/v2/compute/network/interfaces") as resp: + if resp.status < 300: + try: + json_data = await resp.json() + if json_data: + for interface in json_data: + if "name" in interface and interface["name"] == "eth{}".format( + hostonly_interface_number - 1): + if "ip_address" in interface and len(interface["ip_address"]) > 0: + return interface["ip_address"] + except ValueError: + pass + except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError): + pass remaining_try -= 1 await asyncio.sleep(1) raise GNS3VMError("Could not find guest IP address for {}".format(self.vmname)) diff --git a/gns3server/endpoints/controller/links.py b/gns3server/endpoints/controller/links.py index 46f313be..ff1c393d 100644 --- a/gns3server/endpoints/controller/links.py +++ b/gns3server/endpoints/controller/links.py @@ -19,8 +19,8 @@ API endpoints for links. """ -import aiohttp import multidict +import aiohttp from fastapi import APIRouter, Depends, Request, status from fastapi.responses import StreamingResponse @@ -31,9 +31,13 @@ from uuid import UUID from gns3server.controller import Controller from gns3server.controller.controller_error import ControllerError from gns3server.controller.link import Link +from gns3server.utils.http_client import HTTPClient from gns3server.endpoints.schemas.common import ErrorMessage from gns3server.endpoints import schemas +import logging +log = logging.getLogger(__name__) + router = APIRouter() responses = { @@ -201,12 +205,13 @@ async def pcap(request: Request, link: Link = Depends(dep_link)): async def compute_pcap_stream(): - connector = aiohttp.TCPConnector(limit=None, force_close=True) - async with aiohttp.ClientSession(connector=connector, headers=headers) as session: - async with session.request(request.method, pcap_streaming_url, timeout=None, data=body) as compute_response: - async for data in compute_response.content.iter_any(): + try: + async with HTTPClient.request(request.method, pcap_streaming_url, timeout=None, data=body) as response: + async for data in response.content.iter_any(): if not data: break yield data + except aiohttp.ClientError as e: + raise ControllerError(f"Client error received when receiving pcap stream from compute: {e}") return StreamingResponse(compute_pcap_stream(), media_type="application/vnd.tcpdump.pcap") diff --git a/gns3server/endpoints/controller/nodes.py b/gns3server/endpoints/controller/nodes.py index 0f65043e..c2f6cca0 100644 --- a/gns3server/endpoints/controller/nodes.py +++ b/gns3server/endpoints/controller/nodes.py @@ -32,6 +32,7 @@ from gns3server.controller import Controller from gns3server.controller.node import Node from gns3server.controller.project import Project from gns3server.utils import force_unix_path +from gns3server.utils.http_client import HTTPClient from gns3server.controller.controller_error import ControllerForbiddenError from gns3server.endpoints.schemas.common import ErrorMessage from gns3server.endpoints import schemas @@ -400,18 +401,17 @@ async def ws_console(websocket: WebSocket, node: Node = Depends(dep_node)): try: # receive WebSocket data from compute console WebSocket and forward to client. - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True)) as session: - async with session.ws_connect(ws_console_compute_url) as ws_console_compute: - asyncio.ensure_future(ws_receive(ws_console_compute)) - async for msg in ws_console_compute: - if msg.type == aiohttp.WSMsgType.TEXT: - await websocket.send_text(msg.data) - elif msg.type == aiohttp.WSMsgType.BINARY: - await websocket.send_bytes(msg.data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - except aiohttp.client_exceptions.ClientResponseError as e: - log.error(f"Client response error received when forwarding to compute console WebSocket: {e}") + async with HTTPClient.get_client().ws_connect(ws_console_compute_url) as ws_console_compute: + asyncio.ensure_future(ws_receive(ws_console_compute)) + async for msg in ws_console_compute: + if msg.type == aiohttp.WSMsgType.TEXT: + await websocket.send_text(msg.data) + elif msg.type == aiohttp.WSMsgType.BINARY: + await websocket.send_bytes(msg.data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break + except aiohttp.ClientError as e: + log.error(f"Client error received when forwarding to compute console WebSocket: {e}") @router.post("/console/reset", diff --git a/gns3server/endpoints/controller/projects.py b/gns3server/endpoints/controller/projects.py index 16d18e0f..a4bd174d 100644 --- a/gns3server/endpoints/controller/projects.py +++ b/gns3server/endpoints/controller/projects.py @@ -29,7 +29,7 @@ import time import logging log = logging.getLogger() -from fastapi import APIRouter, Depends, Request, Body, Query, HTTPException, status, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Request, Body, HTTPException, status, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse, FileResponse from websockets.exceptions import ConnectionClosed, WebSocketException diff --git a/gns3server/utils/http_client.py b/gns3server/utils/http_client.py new file mode 100644 index 00000000..79d5639e --- /dev/null +++ b/gns3server/utils/http_client.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# +# Copyright (C) 2020 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import aiohttp +import socket + +import logging +log = logging.getLogger(__name__) + + +class HTTPClient: + """ + HTTP client for request to computes and external services. + """ + + _aiohttp_client: aiohttp.ClientSession = None + + @classmethod + def get_client(cls) -> aiohttp.ClientSession: + if cls._aiohttp_client is None: + cls._aiohttp_client = aiohttp.ClientSession(connector=aiohttp.TCPConnector(family=socket.AF_INET)) + return cls._aiohttp_client + + @classmethod + async def close_session(cls): + if cls._aiohttp_client: + await cls._aiohttp_client.close() + cls._aiohttp_client = None + + @classmethod + def request(cls, method: str, url: str, user: str = None, password: str = None, **kwargs): + + client = cls.get_client() + basic_auth = None + if user: + if not password: + password = "" + try: + basic_auth = aiohttp.BasicAuth(user, password, "utf-8") + except ValueError as e: + log.error(f"Basic authentication set-up error: {e}") + + return client.request(method, url, auth=basic_auth, **kwargs) + + @classmethod + def get(cls, path, **kwargs): + return cls.request("GET", path, **kwargs) + + @classmethod + def post(cls, path, **kwargs): + return cls.request("POST", path, **kwargs) + + @classmethod + def put(cls, path, **kwargs): + return cls.request("PUT", path, **kwargs) diff --git a/gns3server/utils/interfaces.py b/gns3server/utils/interfaces.py index 0431c5ad..9d865a7f 100644 --- a/gns3server/utils/interfaces.py +++ b/gns3server/utils/interfaces.py @@ -18,12 +18,12 @@ import os import sys -import aiohttp import socket import struct import psutil from .windows_service import check_windows_service_is_running +from gns3server.compute.compute_error import ComputeError from gns3server.config import Config if psutil.version_info < (3, 0, 0): @@ -162,7 +162,7 @@ def is_interface_up(interface): return True return False except OSError as e: - raise aiohttp.web.HTTPInternalServerError(text="Exception when checking if {} is up: {}".format(interface, e)) + raise ComputeError(f"Exception when checking if {interface} is up: {e}") else: # TODO: Windows & OSX support return True @@ -221,13 +221,13 @@ def interfaces(): results = get_windows_interfaces() except ImportError: message = "pywin32 module is not installed, please install it on the server to get the available interface names" - raise aiohttp.web.HTTPInternalServerError(text=message) + raise ComputeError(message) except Exception as e: log.error("uncaught exception {type}".format(type=type(e)), exc_info=1) - raise aiohttp.web.HTTPInternalServerError(text="uncaught exception: {}".format(e)) + raise ComputeError(f"uncaught exception: {e}") if service_installed is False: - raise aiohttp.web.HTTPInternalServerError(text="The Winpcap or Npcap is not installed or running") + raise ComputeError("The Winpcap or Npcap is not installed or running") # This interface have special behavior for result in results: diff --git a/gns3server/utils/path.py b/gns3server/utils/path.py index 94a1ee64..3430a2f3 100644 --- a/gns3server/utils/path.py +++ b/gns3server/utils/path.py @@ -16,8 +16,8 @@ # along with this program. If not, see . import os -import aiohttp +from fastapi import HTTPException, status from ..config import Config @@ -33,7 +33,7 @@ def get_default_project_directory(): try: os.makedirs(path, exist_ok=True) except OSError as e: - raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e)) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"Could not create project directory: {e}") return path @@ -52,4 +52,4 @@ def check_path_allowed(path): return if "local" in config and config.getboolean("local") is False: - raise aiohttp.web.HTTPForbidden(text="The path is not allowed") + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The path is not allowed") diff --git a/gns3server/utils/windows_service.py b/gns3server/utils/windows_service.py index 7b2e2310..07762a8c 100644 --- a/gns3server/utils/windows_service.py +++ b/gns3server/utils/windows_service.py @@ -19,7 +19,7 @@ Check for Windows service. """ -import aiohttp +from gns3server.compute.compute_error import ComputeError def check_windows_service_is_running(service_name): @@ -35,5 +35,5 @@ def check_windows_service_is_running(service_name): if e.winerror == 1060: return False else: - raise aiohttp.web.HTTPInternalServerError(text="Could not check if the {} service is running: {}".format(service_name, e.strerror)) + raise ComputeError(f"Could not check if the {service_name} service is running: {e.strerror}") return True diff --git a/tests/compute/test_port_manager.py b/tests/compute/test_port_manager.py index a2613521..0e3b878f 100644 --- a/tests/compute/test_port_manager.py +++ b/tests/compute/test_port_manager.py @@ -15,9 +15,10 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import aiohttp import pytest import uuid + +from fastapi import HTTPException from unittest.mock import patch from gns3server.compute.port_manager import PortManager @@ -94,7 +95,7 @@ def test_reserve_udp_port(): pm = PortManager() project = Project(project_id=str(uuid.uuid4())) pm.reserve_udp_port(20000, project) - with pytest.raises(aiohttp.web.HTTPConflict): + with pytest.raises(HTTPException): pm.reserve_udp_port(20000, project) @@ -102,7 +103,7 @@ def test_reserve_udp_port_outside_range(): pm = PortManager() project = Project(project_id=str(uuid.uuid4())) - with pytest.raises(aiohttp.web.HTTPConflict): + with pytest.raises(HTTPException): pm.reserve_udp_port(80, project) @@ -123,7 +124,7 @@ def test_find_unused_port(): def test_find_unused_port_invalid_range(): - with pytest.raises(aiohttp.web.HTTPConflict): + with pytest.raises(HTTPException): p = PortManager().find_unused_port(10000, 1000) diff --git a/tests/endpoints/controller/test_computes.py b/tests/endpoints/controller/test_computes.py index 431a5218..acea9a2d 100644 --- a/tests/endpoints/controller/test_computes.py +++ b/tests/endpoints/controller/test_computes.py @@ -74,8 +74,6 @@ async def test_compute_get(controller_api): response = await controller_api.get("/computes/my_compute_id") assert response.status_code == 200 - print(response.json) - #assert response.json["protocol"] == "http" @pytest.mark.asyncio diff --git a/tests/utils/test_path.py b/tests/utils/test_path.py index e73de26f..2c709f94 100644 --- a/tests/utils/test_path.py +++ b/tests/utils/test_path.py @@ -17,8 +17,8 @@ import os import pytest -import aiohttp +from fastapi import HTTPException from gns3server.utils.path import check_path_allowed, get_default_project_directory @@ -27,7 +27,7 @@ def test_check_path_allowed(config, tmpdir): config.set("Server", "local", False) config.set("Server", "projects_path", str(tmpdir)) - with pytest.raises(aiohttp.web.HTTPForbidden): + with pytest.raises(HTTPException): check_path_allowed("/private") config.set("Server", "local", True)