mirror of
https://github.com/GNS3/gns3-server
synced 2024-12-24 15:58:08 +00:00
Use aiofiles where relevant.
This commit is contained in:
parent
b0df7ecabf
commit
af80b0bb6e
@ -20,6 +20,7 @@ import os
|
||||
import struct
|
||||
import stat
|
||||
import asyncio
|
||||
import aiofiles
|
||||
|
||||
import aiohttp
|
||||
import socket
|
||||
@ -46,6 +47,8 @@ from .nios.nio_ethernet import NIOEthernet
|
||||
from ..utils.images import md5sum, remove_checksum, images_directories, default_images_directory, list_images
|
||||
from .error import NodeError, ImageMissingError
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class BaseManager:
|
||||
|
||||
@ -456,7 +459,7 @@ class BaseManager:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while nio.capturing:
|
||||
data = f.read(4096)
|
||||
data = f.read(CHUNK_SIZE)
|
||||
if not data:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
@ -594,18 +597,18 @@ class BaseManager:
|
||||
path = os.path.abspath(os.path.join(directory, *os.path.split(filename)))
|
||||
if os.path.commonprefix([directory, path]) != directory:
|
||||
raise aiohttp.web.HTTPForbidden(text="Could not write image: {}, {} is forbidden".format(filename, path))
|
||||
log.info("Writing image file %s", path)
|
||||
log.info("Writing image file to '{}'".format(path))
|
||||
try:
|
||||
remove_checksum(path)
|
||||
# We store the file under his final name only when the upload is finished
|
||||
tmp_path = path + ".tmp"
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(tmp_path, 'wb') as f:
|
||||
async with aiofiles.open(tmp_path, 'wb') as f:
|
||||
while True:
|
||||
packet = await stream.read(4096)
|
||||
if not packet:
|
||||
chunk = await stream.read(CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(packet)
|
||||
await f.write(chunk)
|
||||
os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC)
|
||||
shutil.move(tmp_path, path)
|
||||
await cancellable_wait_run_in_executor(md5sum, path)
|
||||
|
@ -37,6 +37,7 @@ log = logging.getLogger(__name__)
|
||||
DOCKER_MINIMUM_API_VERSION = "1.25"
|
||||
DOCKER_MINIMUM_VERSION = "1.13"
|
||||
DOCKER_PREFERRED_API_VERSION = "1.30"
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class Docker(BaseManager):
|
||||
@ -206,7 +207,7 @@ class Docker(BaseManager):
|
||||
content = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = await response.content.read(1024)
|
||||
chunk = await response.content.read(CHUNK_SIZE)
|
||||
except aiohttp.ServerDisconnectedError:
|
||||
log.error("Disconnected from server while pulling Docker image '{}' from docker hub".format(image))
|
||||
break
|
||||
|
@ -320,28 +320,6 @@ class Compute:
|
||||
raise aiohttp.web.HTTPNotFound(text="{} not found on compute".format(image))
|
||||
return response
|
||||
|
||||
async def stream_file(self, project, path, timeout=None):
|
||||
"""
|
||||
Read file of a project and stream it
|
||||
|
||||
:param project: A project object
|
||||
:param path: The path of the file in the project
|
||||
:param timeout: timeout
|
||||
:returns: A file stream
|
||||
"""
|
||||
|
||||
url = self._getUrl("/projects/{}/stream/{}".format(project.id, path))
|
||||
response = await self._session().request("GET", url, auth=self._auth, timeout=timeout)
|
||||
if response.status == 404:
|
||||
raise aiohttp.web.HTTPNotFound(text="file '{}' not found on compute".format(path))
|
||||
elif response.status == 403:
|
||||
raise aiohttp.web.HTTPForbidden(text="forbidden to open '{}' on compute".format(path))
|
||||
elif response.status != 200:
|
||||
raise aiohttp.web.HTTPInternalServerError(text="Unexpected error {}: {}: while opening {} on compute".format(response.status,
|
||||
response.reason,
|
||||
path))
|
||||
return response
|
||||
|
||||
async def http_query(self, method, path, data=None, dont_connect=False, **kwargs):
|
||||
"""
|
||||
:param dont_connect: If true do not reconnect if not connected
|
||||
|
@ -19,6 +19,7 @@ import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import zipfile
|
||||
import tempfile
|
||||
@ -28,6 +29,8 @@ from datetime import datetime
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
async def export_project(zstream, project, temporary_dir, include_images=False, keep_compute_id=False, allow_all_nodes=False, reset_mac_addresses=False):
|
||||
"""
|
||||
@ -36,13 +39,13 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
|
||||
The file will be read chunk by chunk when you iterate over the zip stream.
|
||||
Some files like snapshots and packet captures are ignored.
|
||||
|
||||
:param zstream: ZipStream object
|
||||
:param project: Project instance
|
||||
:param temporary_dir: A temporary dir where to store intermediate data
|
||||
:param include images: save OS images to the zip file
|
||||
:param keep_compute_id: If false replace all compute id by local (standard behavior for .gns3project to make it portable)
|
||||
:param allow_all_nodes: Allow all nodes type to be include in the zip even if not portable
|
||||
:param reset_mac_addresses: Reset MAC addresses for every nodes.
|
||||
|
||||
:returns: ZipStream object
|
||||
"""
|
||||
|
||||
# To avoid issue with data not saved we disallow the export of a running project
|
||||
@ -80,28 +83,28 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
|
||||
zstream.write(path, os.path.relpath(path, project._path))
|
||||
|
||||
# Export files from remote computes
|
||||
downloaded_files = set()
|
||||
for compute in project.computes:
|
||||
if compute.id != "local":
|
||||
compute_files = await compute.list_files(project)
|
||||
for compute_file in compute_files:
|
||||
if _is_exportable(compute_file["path"]):
|
||||
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
|
||||
f = open(fd, "wb", closefd=True)
|
||||
log.debug("Downloading file '{}' from compute '{}'".format(compute_file["path"], compute.id))
|
||||
response = await compute.download_file(project, compute_file["path"])
|
||||
#if response.status != 200:
|
||||
# raise aiohttp.web.HTTPConflict(text="Cannot export file from compute '{}'. Compute returned status code {}.".format(compute.id, response.status))
|
||||
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
|
||||
async with aiofiles.open(fd, 'wb') as f:
|
||||
while True:
|
||||
try:
|
||||
data = await response.content.read(1024)
|
||||
data = await response.content.read(CHUNK_SIZE)
|
||||
except asyncio.TimeoutError:
|
||||
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading file '{}' from remote compute {}:{}".format(compute_file["path"], compute.host, compute.port))
|
||||
if not data:
|
||||
break
|
||||
f.write(data)
|
||||
await f.write(data)
|
||||
response.close()
|
||||
f.close()
|
||||
_patch_mtime(temp_path)
|
||||
zstream.write(temp_path, arcname=compute_file["path"])
|
||||
downloaded_files.add(compute_file['path'])
|
||||
|
||||
|
||||
def _patch_mtime(path):
|
||||
@ -262,30 +265,26 @@ async def _export_remote_images(project, compute_id, image_type, image, project_
|
||||
Export specific image from remote compute.
|
||||
"""
|
||||
|
||||
log.info("Downloading image '{}' from compute '{}'".format(image, compute_id))
|
||||
|
||||
log.debug("Downloading image '{}' from compute '{}'".format(image, compute_id))
|
||||
try:
|
||||
compute = [compute for compute in project.computes if compute.id == compute_id][0]
|
||||
except IndexError:
|
||||
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute doesn't exist.".format(compute_id))
|
||||
|
||||
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
|
||||
f = open(fd, "wb", closefd=True)
|
||||
response = await compute.download_image(image_type, image)
|
||||
|
||||
if response.status != 200:
|
||||
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute returned status code {}.".format(compute_id, response.status))
|
||||
raise aiohttp.web.HTTPConflict(text="Cannot export image from compute '{}'. Compute returned status code {}.".format(compute_id, response.status))
|
||||
|
||||
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
|
||||
async with aiofiles.open(fd, 'wb') as f:
|
||||
while True:
|
||||
try:
|
||||
data = await response.content.read(1024)
|
||||
data = await response.content.read(CHUNK_SIZE)
|
||||
except asyncio.TimeoutError:
|
||||
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading image '{}' from remote compute {}:{}".format(image, compute.host, compute.port))
|
||||
if not data:
|
||||
break
|
||||
f.write(data)
|
||||
await f.write(data)
|
||||
response.close()
|
||||
f.close()
|
||||
arcname = os.path.join("images", image_type, image)
|
||||
log.info("Saved {}".format(arcname))
|
||||
project_zipfile.write(temp_path, arcname=arcname, compress_type=zipfile.ZIP_DEFLATED)
|
||||
|
@ -20,7 +20,6 @@ import sys
|
||||
import json
|
||||
import uuid
|
||||
import shutil
|
||||
import asyncio
|
||||
import zipfile
|
||||
import aiohttp
|
||||
import itertools
|
||||
|
@ -971,7 +971,8 @@ class Project:
|
||||
try:
|
||||
begin = time.time()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with aiozipstream.ZipFile(compression=zipfile.ZIP_DEFLATED) as zstream:
|
||||
# Do not compress the exported project when duplicating
|
||||
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
|
||||
await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True)
|
||||
|
||||
# export the project to a temporary location
|
||||
@ -985,7 +986,7 @@ class Project:
|
||||
with open(project_path, "rb") as f:
|
||||
project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
|
||||
|
||||
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.id, time.time() - begin))
|
||||
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.name, time.time() - begin))
|
||||
except (ValueError, OSError, UnicodeEncodeError) as e:
|
||||
raise aiohttp.web.HTTPConflict(text="Cannot duplicate project: {}".format(str(e)))
|
||||
|
||||
|
@ -101,7 +101,7 @@ class Snapshot:
|
||||
async with aiofiles.open(self.path, 'wb') as f:
|
||||
async for chunk in zstream:
|
||||
await f.write(chunk)
|
||||
log.info("Snapshot '{}' created in {:.4f} seconds".format(self.path, time.time() - begin))
|
||||
log.info("Snapshot '{}' created in {:.4f} seconds".format(self.name, time.time() - begin))
|
||||
except (ValueError, OSError, RuntimeError) as e:
|
||||
raise aiohttp.web.HTTPConflict(text="Could not create snapshot file '{}': {}".format(self.path, e))
|
||||
|
||||
|
@ -493,7 +493,7 @@ class DynamipsVMHandler:
|
||||
if filename[0] == ".":
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
await response.file(image_path)
|
||||
await response.stream_file(image_path)
|
||||
|
||||
@Route.post(
|
||||
r"/projects/{project_id}/dynamips/nodes/{node_id}/duplicate",
|
||||
|
@ -451,4 +451,4 @@ class IOUHandler:
|
||||
if filename[0] == ".":
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
await response.file(image_path)
|
||||
await response.stream_file(image_path)
|
||||
|
@ -37,6 +37,8 @@ from gns3server.schemas.project import (
|
||||
import logging
|
||||
log = logging.getLogger()
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class ProjectHandler:
|
||||
|
||||
@ -248,64 +250,7 @@ class ProjectHandler:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
path = os.path.join(project.path, path)
|
||||
|
||||
response.content_type = "application/octet-stream"
|
||||
response.set_status(200)
|
||||
response.enable_chunked_encoding()
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while True:
|
||||
data = f.read(4096)
|
||||
if not data:
|
||||
break
|
||||
await response.write(data)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
@Route.get(
|
||||
r"/projects/{project_id}/stream/{path:.+}",
|
||||
description="Stream a file from a project",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
status_codes={
|
||||
200: "File returned",
|
||||
403: "Permission denied",
|
||||
404: "The file doesn't exist"
|
||||
})
|
||||
async def stream_file(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project = pm.get_project(request.match_info["project_id"])
|
||||
path = request.match_info["path"]
|
||||
path = os.path.normpath(path)
|
||||
|
||||
# Raise an error if user try to escape
|
||||
if path[0] == ".":
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
path = os.path.join(project.path, path)
|
||||
|
||||
response.content_type = "application/octet-stream"
|
||||
response.set_status(200)
|
||||
response.enable_chunked_encoding()
|
||||
|
||||
# FIXME: file streaming is never stopped
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while True:
|
||||
data = f.read(4096)
|
||||
if not data:
|
||||
await asyncio.sleep(0.1)
|
||||
await response.write(data)
|
||||
except FileNotFoundError:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
await response.stream_file(path)
|
||||
|
||||
@Route.post(
|
||||
r"/projects/{project_id}/files/{path:.+}",
|
||||
@ -338,7 +283,7 @@ class ProjectHandler:
|
||||
with open(path, 'wb+') as f:
|
||||
while True:
|
||||
try:
|
||||
chunk = await request.content.read(1024)
|
||||
chunk = await request.content.read(CHUNK_SIZE)
|
||||
except asyncio.TimeoutError:
|
||||
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
|
||||
if not chunk:
|
||||
@ -349,64 +294,3 @@ class ProjectHandler:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
@Route.get(
|
||||
r"/projects/{project_id}/export",
|
||||
description="Export a project as a portable archive",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
raw=True,
|
||||
status_codes={
|
||||
200: "File returned",
|
||||
404: "The project doesn't exist"
|
||||
})
|
||||
async def export_project(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project = pm.get_project(request.match_info["project_id"])
|
||||
response.content_type = 'application/gns3project'
|
||||
response.headers['CONTENT-DISPOSITION'] = 'attachment; filename="{}.gns3project"'.format(project.name)
|
||||
response.enable_chunked_encoding()
|
||||
await response.prepare(request)
|
||||
|
||||
include_images = bool(int(request.json.get("include_images", "0")))
|
||||
for data in project.export(include_images=include_images):
|
||||
await response.write(data)
|
||||
|
||||
#await response.write_eof() #FIXME: shound't be needed anymore
|
||||
|
||||
@Route.post(
|
||||
r"/projects/{project_id}/import",
|
||||
description="Import a project from a portable archive",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
raw=True,
|
||||
output=PROJECT_OBJECT_SCHEMA,
|
||||
status_codes={
|
||||
200: "Project imported",
|
||||
403: "Forbidden to import project"
|
||||
})
|
||||
async def import_project(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project_id = request.match_info["project_id"]
|
||||
project = pm.create_project(project_id=project_id)
|
||||
|
||||
# We write the content to a temporary location and after we extract it all.
|
||||
# It could be more optimal to stream this but it is not implemented in Python.
|
||||
# Spooled means the file is temporary kept in memory until max_size is reached
|
||||
try:
|
||||
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
|
||||
while True:
|
||||
chunk = await request.content.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
temp.write(chunk)
|
||||
project.import_zip(temp, gns3vm=bool(int(request.GET.get("gns3vm", "1"))))
|
||||
except OSError as e:
|
||||
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
|
||||
|
||||
response.json(project)
|
||||
response.set_status(201)
|
||||
|
@ -576,4 +576,4 @@ class QEMUHandler:
|
||||
if filename[0] == ".":
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
await response.file(image_path)
|
||||
await response.stream_file(image_path)
|
||||
|
@ -16,11 +16,11 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import tempfile
|
||||
import zipfile
|
||||
import aiofiles
|
||||
import time
|
||||
|
||||
from gns3server.web.route import Route
|
||||
@ -51,6 +51,8 @@ async def process_websocket(ws):
|
||||
except aiohttp.WSServerHandshakeError:
|
||||
pass
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class ProjectHandler:
|
||||
|
||||
@ -304,7 +306,6 @@ class ProjectHandler:
|
||||
controller = Controller.instance()
|
||||
project = await controller.get_loaded_project(request.match_info["project_id"])
|
||||
|
||||
|
||||
try:
|
||||
begin = time.time()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -321,8 +322,8 @@ class ProjectHandler:
|
||||
async for chunk in zstream:
|
||||
await response.write(chunk)
|
||||
|
||||
log.info("Project '{}' exported in {:.4f} seconds".format(project.id, time.time() - begin))
|
||||
#await response.write_eof() #FIXME: shound't be needed anymore
|
||||
log.info("Project '{}' exported in {:.4f} seconds".format(project.name, time.time() - begin))
|
||||
|
||||
# Will be raise if you have no space left or permission issue on your temporary directory
|
||||
# RuntimeError: something was wrong during the zip process
|
||||
except (ValueError, OSError, RuntimeError) as e:
|
||||
@ -354,29 +355,23 @@ class ProjectHandler:
|
||||
|
||||
# We write the content to a temporary location and after we extract it all.
|
||||
# It could be more optimal to stream this but it is not implemented in Python.
|
||||
# Spooled means the file is temporary kept in memory until max_size is reached
|
||||
# Cannot use tempfile.SpooledTemporaryFile(max_size=10000) in Python 3.7 due
|
||||
# to a bug https://bugs.python.org/issue26175
|
||||
try:
|
||||
if sys.version_info >= (3, 7) and sys.version_info < (3, 8):
|
||||
with tempfile.TemporaryFile() as temp:
|
||||
begin = time.time()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
temp_project_path = os.path.join(tmpdir, "project.zip")
|
||||
async with aiofiles.open(temp_project_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await request.content.read(1024)
|
||||
chunk = await request.content.read(CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
temp.write(chunk)
|
||||
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
|
||||
else:
|
||||
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
|
||||
while True:
|
||||
chunk = await request.content.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
temp.write(chunk)
|
||||
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
|
||||
await f.write(chunk)
|
||||
|
||||
with open(temp_project_path, "rb") as f:
|
||||
project = await import_project(controller, request.match_info["project_id"], f, location=path, name=name)
|
||||
|
||||
log.info("Project '{}' imported in {:.4f} seconds".format(project.name, time.time() - begin))
|
||||
except OSError as e:
|
||||
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
|
||||
|
||||
response.json(project)
|
||||
response.set_status(201)
|
||||
|
||||
@ -443,7 +438,7 @@ class ProjectHandler:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while True:
|
||||
data = f.read(4096)
|
||||
data = f.read(CHUNK_SIZE)
|
||||
if not data:
|
||||
break
|
||||
await response.write(data)
|
||||
@ -483,7 +478,7 @@ class ProjectHandler:
|
||||
with open(path, 'wb+') as f:
|
||||
while True:
|
||||
try:
|
||||
chunk = await request.content.read(1024)
|
||||
chunk = await request.content.read(CHUNK_SIZE)
|
||||
except asyncio.TimeoutError:
|
||||
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
|
||||
if not chunk:
|
||||
|
@ -53,7 +53,7 @@ class SymbolHandler:
|
||||
|
||||
controller = Controller.instance()
|
||||
try:
|
||||
await response.file(controller.symbols.get_path(request.match_info["symbol_id"]))
|
||||
await response.stream_file(controller.symbols.get_path(request.match_info["symbol_id"]))
|
||||
except (KeyError, OSError) as e:
|
||||
log.warning("Could not get symbol file: {}".format(e))
|
||||
response.set_status(404)
|
||||
|
@ -92,7 +92,7 @@ class IndexHandler:
|
||||
if not os.path.exists(static):
|
||||
static = get_static_path(os.path.join('web-ui', 'index.html'))
|
||||
|
||||
await response.file(static)
|
||||
await response.stream_file(static)
|
||||
|
||||
@Route.get(
|
||||
r"/v1/version",
|
||||
|
@ -20,7 +20,7 @@ import jsonschema
|
||||
import aiohttp
|
||||
import aiohttp.web
|
||||
import mimetypes
|
||||
import asyncio
|
||||
import aiofiles
|
||||
import logging
|
||||
import jinja2
|
||||
import sys
|
||||
@ -32,6 +32,8 @@ from ..version import __version__
|
||||
log = logging.getLogger(__name__)
|
||||
renderer = jinja2.Environment(loader=jinja2.FileSystemLoader(get_resource('templates')))
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class Response(aiohttp.web.Response):
|
||||
|
||||
@ -112,16 +114,21 @@ class Response(aiohttp.web.Response):
|
||||
raise aiohttp.web.HTTPBadRequest(text="{}".format(e))
|
||||
self.body = json.dumps(answer, indent=4, sort_keys=True).encode('utf-8')
|
||||
|
||||
async def file(self, path, status=200, set_content_length=True):
|
||||
async def stream_file(self, path, status=200, set_content_type=None, set_content_length=True):
|
||||
"""
|
||||
Return a file as a response
|
||||
Stream a file as a response
|
||||
"""
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
|
||||
if not set_content_type:
|
||||
ct, encoding = mimetypes.guess_type(path)
|
||||
if not ct:
|
||||
ct = 'application/octet-stream'
|
||||
else:
|
||||
ct = set_content_type
|
||||
|
||||
if encoding:
|
||||
self.headers[aiohttp.hdrs.CONTENT_ENCODING] = encoding
|
||||
self.content_type = ct
|
||||
@ -136,16 +143,13 @@ class Response(aiohttp.web.Response):
|
||||
self.set_status(status)
|
||||
|
||||
try:
|
||||
with open(path, 'rb') as fobj:
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
await self.prepare(self._request)
|
||||
|
||||
while True:
|
||||
data = fobj.read(4096)
|
||||
data = await f.read(CHUNK_SIZE)
|
||||
if not data:
|
||||
break
|
||||
await self.write(data)
|
||||
# await self.drain()
|
||||
|
||||
except FileNotFoundError:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
|
@ -293,15 +293,6 @@ def test_json(compute):
|
||||
}
|
||||
|
||||
|
||||
def test_streamFile(project, async_run, compute):
|
||||
response = MagicMock()
|
||||
response.status = 200
|
||||
with asyncio_patch("aiohttp.ClientSession.request", return_value=response) as mock:
|
||||
async_run(compute.stream_file(project, "test/titi", timeout=120))
|
||||
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/stream/test/titi".format(project.id), auth=None, timeout=120)
|
||||
async_run(compute.close())
|
||||
|
||||
|
||||
def test_downloadFile(project, async_run, compute):
|
||||
response = MagicMock()
|
||||
response.status = 200
|
||||
@ -310,6 +301,7 @@ def test_downloadFile(project, async_run, compute):
|
||||
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/files/test/titi".format(project.id), auth=None)
|
||||
async_run(compute.close())
|
||||
|
||||
|
||||
def test_close(compute, async_run):
|
||||
assert compute.connected is True
|
||||
async_run(compute.close())
|
||||
|
@ -34,11 +34,11 @@ def test_response_file(async_run, tmpdir, response):
|
||||
with open(filename, 'w+') as f:
|
||||
f.write('world')
|
||||
|
||||
async_run(response.file(filename))
|
||||
async_run(response.stream_file(filename))
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_response_file_not_found(async_run, tmpdir, response):
|
||||
filename = str(tmpdir / 'hello-not-found')
|
||||
|
||||
pytest.raises(HTTPNotFound, lambda: async_run(response.file(filename)))
|
||||
pytest.raises(HTTPNotFound, lambda: async_run(response.stream_file(filename)))
|
||||
|
Loading…
Reference in New Issue
Block a user