mirror of
https://github.com/GNS3/gns3-server
synced 2025-01-23 22:41:02 +00:00
Add zstandard compression support for project export
This commit is contained in:
parent
37c7bc4956
commit
8a964390f8
@ -21,10 +21,10 @@ API routes for projects.
|
||||
import os
|
||||
import asyncio
|
||||
import tempfile
|
||||
import zipfile
|
||||
import aiofiles
|
||||
import time
|
||||
import urllib.parse
|
||||
import gns3server.utils.zipfile_zstd as zipfile
|
||||
|
||||
import logging
|
||||
|
||||
@ -285,7 +285,7 @@ async def export_project(
|
||||
include_snapshots: bool = False,
|
||||
include_images: bool = False,
|
||||
reset_mac_addresses: bool = False,
|
||||
compression: str = "zip",
|
||||
compression: schemas.ProjectCompression = "zstd",
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Export a project as a portable archive.
|
||||
@ -300,6 +300,8 @@ async def export_project(
|
||||
compression = zipfile.ZIP_BZIP2
|
||||
elif compression_query == "lzma":
|
||||
compression = zipfile.ZIP_LZMA
|
||||
elif compression_query == "zstd":
|
||||
compression = zipfile.ZIP_ZSTANDARD
|
||||
|
||||
try:
|
||||
begin = time.time()
|
||||
|
@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance
|
||||
from .controller.drawings import Drawing
|
||||
from .controller.gns3vm import GNS3VM
|
||||
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
|
||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
|
||||
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
|
||||
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
|
||||
from .controller.tokens import Token
|
||||
|
@ -102,3 +102,15 @@ class ProjectFile(BaseModel):
|
||||
|
||||
path: str = Field(..., description="File path")
|
||||
md5sum: str = Field(..., description="File checksum")
|
||||
|
||||
|
||||
class ProjectCompression(str, Enum):
|
||||
"""
|
||||
Supported project compression.
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
zip = "zip"
|
||||
bzip2 = "bzip2"
|
||||
lzma = "lzma"
|
||||
zstd = "zstd"
|
||||
|
@ -43,26 +43,37 @@ from zipfile import (
|
||||
stringEndArchive64Locator,
|
||||
)
|
||||
|
||||
|
||||
ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd
|
||||
ZSTANDARD_VERSION = 20
|
||||
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
|
||||
|
||||
|
||||
def _get_compressor(compress_type):
|
||||
def _get_compressor(compress_type, compresslevel=None):
|
||||
"""
|
||||
Return the compressor.
|
||||
"""
|
||||
|
||||
if compress_type == zipfile.ZIP_DEFLATED:
|
||||
from zipfile import zlib
|
||||
|
||||
if compresslevel is not None:
|
||||
return zlib.compressobj(compresslevel, zlib.DEFLATED, -15)
|
||||
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
|
||||
elif compress_type == zipfile.ZIP_BZIP2:
|
||||
from zipfile import bz2
|
||||
|
||||
if compresslevel is not None:
|
||||
return bz2.BZ2Compressor(compresslevel)
|
||||
return bz2.BZ2Compressor()
|
||||
# compresslevel is ignored for ZIP_LZMA
|
||||
elif compress_type == zipfile.ZIP_LZMA:
|
||||
from zipfile import LZMACompressor
|
||||
|
||||
return LZMACompressor()
|
||||
elif compress_type == ZIP_ZSTANDARD:
|
||||
import zstandard as zstd
|
||||
if compresslevel is not None:
|
||||
params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31)
|
||||
return zstd.ZstdCompressor(compression_params=params).compressobj()
|
||||
return zstd.ZstdCompressor().compressobj()
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -129,7 +140,15 @@ class ZipInfo(zipfile.ZipInfo):
|
||||
|
||||
|
||||
class ZipFile(zipfile.ZipFile):
|
||||
def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768):
|
||||
def __init__(
|
||||
self,
|
||||
fileobj=None,
|
||||
mode="w",
|
||||
compression=zipfile.ZIP_STORED,
|
||||
allowZip64=True,
|
||||
compresslevel=None,
|
||||
chunksize=32768
|
||||
):
|
||||
"""Open the ZIP file with mode write "w"."""
|
||||
|
||||
if mode not in ("w",):
|
||||
@ -138,7 +157,13 @@ class ZipFile(zipfile.ZipFile):
|
||||
fileobj = PointerIO()
|
||||
|
||||
self._comment = b""
|
||||
zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64)
|
||||
zipfile.ZipFile.__init__(
|
||||
self, fileobj,
|
||||
mode=mode,
|
||||
compression=compression,
|
||||
compresslevel=compresslevel,
|
||||
allowZip64=allowZip64
|
||||
)
|
||||
self._chunksize = chunksize
|
||||
self.paths_to_write = []
|
||||
|
||||
@ -195,23 +220,33 @@ class ZipFile(zipfile.ZipFile):
|
||||
for chunk in self._close():
|
||||
yield chunk
|
||||
|
||||
def write(self, filename, arcname=None, compress_type=None):
|
||||
def write(self, filename, arcname=None, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Write a file to the archive under the name `arcname`.
|
||||
"""
|
||||
|
||||
kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type}
|
||||
kwargs = {
|
||||
"filename": filename,
|
||||
"arcname": arcname,
|
||||
"compress_type": compress_type,
|
||||
"compresslevel": compresslevel
|
||||
}
|
||||
self.paths_to_write.append(kwargs)
|
||||
|
||||
def write_iter(self, arcname, iterable, compress_type=None):
|
||||
def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Write the bytes iterable `iterable` to the archive under the name `arcname`.
|
||||
"""
|
||||
|
||||
kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type}
|
||||
kwargs = {
|
||||
"arcname": arcname,
|
||||
"iterable": iterable,
|
||||
"compress_type": compress_type,
|
||||
"compresslevel": compresslevel
|
||||
}
|
||||
self.paths_to_write.append(kwargs)
|
||||
|
||||
def writestr(self, arcname, data, compress_type=None):
|
||||
def writestr(self, arcname, data, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Writes a str into ZipFile by wrapping data as a generator
|
||||
"""
|
||||
@ -219,9 +254,9 @@ class ZipFile(zipfile.ZipFile):
|
||||
def _iterable():
|
||||
yield data
|
||||
|
||||
return self.write_iter(arcname, _iterable(), compress_type=compress_type)
|
||||
return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel)
|
||||
|
||||
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None):
|
||||
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Put the bytes from filename into the archive under the name `arcname`.
|
||||
"""
|
||||
@ -256,6 +291,11 @@ class ZipFile(zipfile.ZipFile):
|
||||
else:
|
||||
zinfo.compress_type = compress_type
|
||||
|
||||
if compresslevel is None:
|
||||
zinfo._compresslevel = self.compresslevel
|
||||
else:
|
||||
zinfo._compresslevel = compresslevel
|
||||
|
||||
if st:
|
||||
zinfo.file_size = st[6]
|
||||
else:
|
||||
@ -279,7 +319,7 @@ class ZipFile(zipfile.ZipFile):
|
||||
yield self.fp.write(zinfo.FileHeader(False))
|
||||
return
|
||||
|
||||
cmpr = _get_compressor(zinfo.compress_type)
|
||||
cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel)
|
||||
|
||||
# Must overwrite CRC and sizes with correct data later
|
||||
zinfo.CRC = CRC = 0
|
||||
@ -369,6 +409,8 @@ class ZipFile(zipfile.ZipFile):
|
||||
min_version = max(zipfile.BZIP2_VERSION, min_version)
|
||||
elif zinfo.compress_type == zipfile.ZIP_LZMA:
|
||||
min_version = max(zipfile.LZMA_VERSION, min_version)
|
||||
elif zinfo.compress_type == ZIP_ZSTANDARD:
|
||||
min_version = max(ZSTANDARD_VERSION, min_version)
|
||||
|
||||
extract_version = max(min_version, zinfo.extract_version)
|
||||
create_version = max(min_version, zinfo.create_version)
|
||||
|
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
# NOTE: this patches the standard zipfile module
|
||||
from . import _zipfile
|
||||
|
||||
from zipfile import *
|
||||
from zipfile import (
|
||||
ZIP_ZSTANDARD,
|
||||
ZSTANDARD_VERSION,
|
||||
)
|
||||
|
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
@ -0,0 +1,20 @@
|
||||
import functools
|
||||
|
||||
|
||||
class patch:
|
||||
|
||||
originals = {}
|
||||
|
||||
def __init__(self, host, name):
|
||||
self.host = host
|
||||
self.name = name
|
||||
|
||||
def __call__(self, func):
|
||||
original = getattr(self.host, self.name)
|
||||
self.originals[self.name] = original
|
||||
|
||||
functools.update_wrapper(func, original)
|
||||
setattr(self.host, self.name, func)
|
||||
|
||||
return func
|
||||
|
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
@ -0,0 +1,64 @@
|
||||
import zipfile
|
||||
import zstandard as zstd
|
||||
import inspect
|
||||
|
||||
from ._patcher import patch
|
||||
|
||||
|
||||
zipfile.ZIP_ZSTANDARD = 93
|
||||
zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard'
|
||||
zipfile.ZSTANDARD_VERSION = 20
|
||||
|
||||
|
||||
@patch(zipfile, '_check_compression')
|
||||
def zstd_check_compression(compression):
|
||||
if compression == zipfile.ZIP_ZSTANDARD:
|
||||
pass
|
||||
else:
|
||||
patch.originals['_check_compression'](compression)
|
||||
|
||||
|
||||
class ZstdDecompressObjWrapper:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr == 'eof':
|
||||
return False
|
||||
return getattr(self.o, attr)
|
||||
|
||||
|
||||
@patch(zipfile, '_get_decompressor')
|
||||
def zstd_get_decompressor(compress_type):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj())
|
||||
else:
|
||||
return patch.originals['_get_decompressor'](compress_type)
|
||||
|
||||
|
||||
if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters:
|
||||
@patch(zipfile, '_get_compressor')
|
||||
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
if compresslevel is None:
|
||||
compresslevel = 3
|
||||
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||
else:
|
||||
return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel)
|
||||
else:
|
||||
@patch(zipfile, '_get_compressor')
|
||||
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
if compresslevel is None:
|
||||
compresslevel = 3
|
||||
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||
else:
|
||||
return patch.originals['_get_compressor'](compress_type)
|
||||
|
||||
|
||||
@patch(zipfile.ZipInfo, 'FileHeader')
|
||||
def zstd_FileHeader(self, zip64=None):
|
||||
if self.compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION)
|
||||
self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION)
|
||||
return patch.originals['FileHeader'](self, zip64=zip64)
|
Loading…
Reference in New Issue
Block a user