mirror of
https://github.com/etesync/server
synced 2024-12-28 18:28:07 +00:00
Merge: change the server to use FastAPI
FastAPI is much faster (twice as fast in our testing environment), though more importantly it's much faster to develop with, much less error-prone thanks to strong typing, and makes it easier to further extend the server. We currently still use the Django ORM behind the scenes, which means we still get all of the benefits of the django admin UI, and being able to use django for the non-API parts. Merge of #72
This commit is contained in:
commit
259e395c92
@ -62,7 +62,7 @@ Now you can initialise our django app.
|
||||
And you are done! You can now run the debug server just to see everything works as expected by running:
|
||||
|
||||
```
|
||||
./manage.py runserver 0.0.0.0:8000
|
||||
uvicorn etebase_server.asgi:application --port 8000
|
||||
```
|
||||
|
||||
Using the debug server in production is not recommended, so please read the following section for a proper deployment.
|
||||
|
@ -1,3 +0,0 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
@ -32,22 +32,16 @@ class AppSettings:
|
||||
return getattr(settings, self.prefix + name, dflt)
|
||||
|
||||
@cached_property
|
||||
def API_PERMISSIONS(self): # pylint: disable=invalid-name
|
||||
perms = self._setting("API_PERMISSIONS", ("rest_framework.permissions.IsAuthenticated",))
|
||||
def API_PERMISSIONS_READ(self): # pylint: disable=invalid-name
|
||||
perms = self._setting("API_PERMISSIONS_READ", tuple())
|
||||
ret = []
|
||||
for perm in perms:
|
||||
ret.append(self.import_from_str(perm))
|
||||
return ret
|
||||
|
||||
@cached_property
|
||||
def API_AUTHENTICATORS(self): # pylint: disable=invalid-name
|
||||
perms = self._setting(
|
||||
"API_AUTHENTICATORS",
|
||||
(
|
||||
"rest_framework.authentication.TokenAuthentication",
|
||||
"rest_framework.authentication.SessionAuthentication",
|
||||
),
|
||||
)
|
||||
def API_PERMISSIONS_WRITE(self): # pylint: disable=invalid-name
|
||||
perms = self._setting("API_PERMISSIONS_WRITE", tuple())
|
||||
ret = []
|
||||
for perm in perms:
|
||||
ret.append(self.import_from_str(perm))
|
||||
|
@ -1,5 +0,0 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class DrfMsgpackConfig(AppConfig):
|
||||
name = "drf_msgpack"
|
@ -1,14 +0,0 @@
|
||||
import msgpack
|
||||
|
||||
from rest_framework.parsers import BaseParser
|
||||
from rest_framework.exceptions import ParseError
|
||||
|
||||
|
||||
class MessagePackParser(BaseParser):
|
||||
media_type = "application/msgpack"
|
||||
|
||||
def parse(self, stream, media_type=None, parser_context=None):
|
||||
try:
|
||||
return msgpack.unpackb(stream.read(), raw=False)
|
||||
except Exception as exc:
|
||||
raise ParseError("MessagePack parse error - %s" % str(exc))
|
@ -1,15 +0,0 @@
|
||||
import msgpack
|
||||
|
||||
from rest_framework.renderers import BaseRenderer
|
||||
|
||||
|
||||
class MessagePackRenderer(BaseRenderer):
|
||||
media_type = "application/msgpack"
|
||||
format = "msgpack"
|
||||
render_style = "binary"
|
||||
charset = None
|
||||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
if data is None:
|
||||
return b""
|
||||
return msgpack.packb(data, use_bin_type=True)
|
@ -1,3 +0,0 @@
|
||||
from django.shortcuts import render
|
||||
|
||||
# Create your views here.
|
@ -1,12 +0,0 @@
|
||||
from rest_framework import serializers, status
|
||||
|
||||
|
||||
class EtebaseValidationError(serializers.ValidationError):
|
||||
def __init__(self, code, detail, status_code=status.HTTP_400_BAD_REQUEST):
|
||||
super().__init__(
|
||||
{
|
||||
"code": code,
|
||||
"detail": detail,
|
||||
}
|
||||
)
|
||||
self.status_code = status_code
|
@ -1,14 +0,0 @@
|
||||
from rest_framework.parsers import FileUploadParser
|
||||
|
||||
|
||||
class ChunkUploadParser(FileUploadParser):
|
||||
"""
|
||||
Parser for chunk upload data.
|
||||
"""
|
||||
|
||||
def get_filename(self, stream, media_type, parser_context):
|
||||
"""
|
||||
Detects the uploaded file name.
|
||||
"""
|
||||
view = parser_context["view"]
|
||||
return parser_context["kwargs"][view.lookup_field]
|
@ -1,93 +0,0 @@
|
||||
# Copyright © 2017 Tom Hacohen
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, version 3.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from rest_framework import permissions
|
||||
from django_etebase.models import Collection, AccessLevels
|
||||
|
||||
|
||||
def is_collection_admin(collection, user):
|
||||
member = collection.members.filter(user=user).first()
|
||||
return (member is not None) and (member.accessLevel == AccessLevels.ADMIN)
|
||||
|
||||
|
||||
class IsCollectionAdmin(permissions.BasePermission):
|
||||
"""
|
||||
Custom permission to only allow owners of a collection to view it
|
||||
"""
|
||||
|
||||
message = {
|
||||
"detail": "Only collection admins can perform this operation.",
|
||||
"code": "admin_access_required",
|
||||
}
|
||||
|
||||
def has_permission(self, request, view):
|
||||
collection_uid = view.kwargs["collection_uid"]
|
||||
try:
|
||||
collection = view.get_collection_queryset().get(main_item__uid=collection_uid)
|
||||
return is_collection_admin(collection, request.user)
|
||||
except Collection.DoesNotExist:
|
||||
# If the collection does not exist, we want to 404 later, not permission denied.
|
||||
return True
|
||||
|
||||
|
||||
class IsCollectionAdminOrReadOnly(permissions.BasePermission):
|
||||
"""
|
||||
Custom permission to only allow owners of a collection to edit it
|
||||
"""
|
||||
|
||||
message = {
|
||||
"detail": "Only collection admins can edit collections.",
|
||||
"code": "admin_access_required",
|
||||
}
|
||||
|
||||
def has_permission(self, request, view):
|
||||
collection_uid = view.kwargs.get("collection_uid", None)
|
||||
|
||||
# Allow creating new collections
|
||||
if collection_uid is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
collection = view.get_collection_queryset().get(main_item__uid=collection_uid)
|
||||
if request.method in permissions.SAFE_METHODS:
|
||||
return True
|
||||
|
||||
return is_collection_admin(collection, request.user)
|
||||
except Collection.DoesNotExist:
|
||||
# If the collection does not exist, we want to 404 later, not permission denied.
|
||||
return True
|
||||
|
||||
|
||||
class HasWriteAccessOrReadOnly(permissions.BasePermission):
|
||||
"""
|
||||
Custom permission to restrict write
|
||||
"""
|
||||
|
||||
message = {
|
||||
"detail": "You need write access to write to this collection",
|
||||
"code": "no_write_access",
|
||||
}
|
||||
|
||||
def has_permission(self, request, view):
|
||||
collection_uid = view.kwargs["collection_uid"]
|
||||
try:
|
||||
collection = view.get_collection_queryset().get(main_item__uid=collection_uid)
|
||||
if request.method in permissions.SAFE_METHODS:
|
||||
return True
|
||||
else:
|
||||
member = collection.members.get(user=request.user)
|
||||
return member.accessLevel != AccessLevels.READ_ONLY
|
||||
except Collection.DoesNotExist:
|
||||
# If the collection does not exist, we want to 404 later, not permission denied.
|
||||
return True
|
@ -1,19 +0,0 @@
|
||||
from rest_framework.utils.encoders import JSONEncoder as DRFJSONEncoder
|
||||
from rest_framework.renderers import JSONRenderer as DRFJSONRenderer
|
||||
|
||||
from .serializers import b64encode
|
||||
|
||||
|
||||
class JSONEncoder(DRFJSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, bytes) or isinstance(obj, memoryview):
|
||||
return b64encode(obj)
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class JSONRenderer(DRFJSONRenderer):
|
||||
"""
|
||||
Renderer which serializes to JSON with support for our base64
|
||||
"""
|
||||
|
||||
encoder_class = JSONEncoder
|
@ -1,17 +0,0 @@
|
||||
import os.path
|
||||
|
||||
from django.views.static import serve
|
||||
|
||||
|
||||
def sendfile(request, filename, **kwargs):
|
||||
"""
|
||||
Send file using Django dev static file server.
|
||||
|
||||
.. warning::
|
||||
|
||||
Do not use in production. This is only to be used when developing and
|
||||
is provided for convenience only
|
||||
"""
|
||||
dirname = os.path.dirname(filename)
|
||||
basename = os.path.basename(filename)
|
||||
return serve(request, basename, dirname)
|
@ -1,17 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from django.http import HttpResponse
|
||||
|
||||
from ..utils import _convert_file_to_url
|
||||
|
||||
|
||||
def sendfile(request, filename, **kwargs):
|
||||
response = HttpResponse()
|
||||
response['Location'] = _convert_file_to_url(filename)
|
||||
# need to destroy get_host() to stop django
|
||||
# rewriting our location to include http, so that
|
||||
# mod_wsgi is able to do the internal redirect
|
||||
request.get_host = lambda: ''
|
||||
request.build_absolute_uri = lambda location: location
|
||||
|
||||
return response
|
@ -1,12 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from django.http import HttpResponse
|
||||
|
||||
from ..utils import _convert_file_to_url
|
||||
|
||||
|
||||
def sendfile(request, filename, **kwargs):
|
||||
response = HttpResponse()
|
||||
response['X-Accel-Redirect'] = _convert_file_to_url(filename)
|
||||
|
||||
return response
|
@ -1,60 +0,0 @@
|
||||
from email.utils import mktime_tz, parsedate_tz
|
||||
import re
|
||||
|
||||
from django.core.files.base import File
|
||||
from django.http import HttpResponse, HttpResponseNotModified
|
||||
from django.utils.http import http_date
|
||||
|
||||
|
||||
def sendfile(request, filepath, **kwargs):
|
||||
'''Use the SENDFILE_ROOT value composed with the path arrived as argument
|
||||
to build an absolute path with which resolve and return the file contents.
|
||||
|
||||
If the path points to a file out of the root directory (should cover both
|
||||
situations with '..' and symlinks) then a 404 is raised.
|
||||
'''
|
||||
statobj = filepath.stat()
|
||||
|
||||
# Respect the If-Modified-Since header.
|
||||
if not was_modified_since(request.META.get('HTTP_IF_MODIFIED_SINCE'),
|
||||
statobj.st_mtime, statobj.st_size):
|
||||
return HttpResponseNotModified()
|
||||
|
||||
with File(filepath.open('rb')) as f:
|
||||
response = HttpResponse(f.chunks())
|
||||
|
||||
response["Last-Modified"] = http_date(statobj.st_mtime)
|
||||
return response
|
||||
|
||||
|
||||
def was_modified_since(header=None, mtime=0, size=0):
|
||||
"""
|
||||
Was something modified since the user last downloaded it?
|
||||
|
||||
header
|
||||
This is the value of the If-Modified-Since header. If this is None,
|
||||
I'll just return True.
|
||||
|
||||
mtime
|
||||
This is the modification time of the item we're talking about.
|
||||
|
||||
size
|
||||
This is the size of the item we're talking about.
|
||||
"""
|
||||
try:
|
||||
if header is None:
|
||||
raise ValueError
|
||||
matches = re.match(r"^([^;]+)(; length=([0-9]+))?$", header,
|
||||
re.IGNORECASE)
|
||||
header_date = parsedate_tz(matches.group(1))
|
||||
if header_date is None:
|
||||
raise ValueError
|
||||
header_mtime = mktime_tz(header_date)
|
||||
header_len = matches.group(3)
|
||||
if header_len and int(header_len) != size:
|
||||
raise ValueError
|
||||
if mtime > header_mtime:
|
||||
raise ValueError
|
||||
except (AttributeError, ValueError, OverflowError):
|
||||
return True
|
||||
return False
|
@ -1,9 +0,0 @@
|
||||
from django.http import HttpResponse
|
||||
|
||||
|
||||
def sendfile(request, filename, **kwargs):
|
||||
filename = str(filename)
|
||||
response = HttpResponse()
|
||||
response['X-Sendfile'] = filename
|
||||
|
||||
return response
|
@ -1,598 +0,0 @@
|
||||
# Copyright © 2017 Tom Hacohen
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, version 3.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import base64
|
||||
|
||||
from django.core.files.base import ContentFile
|
||||
from django.core import exceptions as django_exceptions
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db import IntegrityError, transaction
|
||||
from rest_framework import serializers, status
|
||||
from . import models
|
||||
from .utils import get_user_queryset, create_user, CallbackContext
|
||||
|
||||
from .exceptions import EtebaseValidationError
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def process_revisions_for_item(item, revision_data):
|
||||
chunks_objs = []
|
||||
chunks = revision_data.pop("chunks_relation")
|
||||
|
||||
revision = models.CollectionItemRevision(**revision_data, item=item)
|
||||
revision.validate_unique() # Verify there aren't any validation issues
|
||||
|
||||
for chunk in chunks:
|
||||
uid = chunk[0]
|
||||
chunk_obj = models.CollectionItemChunk.objects.filter(uid=uid).first()
|
||||
content = chunk[1] if len(chunk) > 1 else None
|
||||
# If the chunk already exists we assume it's fine. Otherwise, we upload it.
|
||||
if chunk_obj is None:
|
||||
if content is not None:
|
||||
chunk_obj = models.CollectionItemChunk(uid=uid, collection=item.collection)
|
||||
chunk_obj.chunkFile.save("IGNORED", ContentFile(content))
|
||||
chunk_obj.save()
|
||||
else:
|
||||
raise EtebaseValidationError("chunk_no_content", "Tried to create a new chunk without content")
|
||||
|
||||
chunks_objs.append(chunk_obj)
|
||||
|
||||
stoken = models.Stoken.objects.create()
|
||||
revision.stoken = stoken
|
||||
revision.save()
|
||||
|
||||
for chunk in chunks_objs:
|
||||
models.RevisionChunkRelation.objects.create(chunk=chunk, revision=revision)
|
||||
return revision
|
||||
|
||||
|
||||
def b64encode(value):
|
||||
return base64.urlsafe_b64encode(value).decode("ascii").strip("=")
|
||||
|
||||
|
||||
def b64decode(data):
|
||||
data += "=" * ((4 - len(data) % 4) % 4)
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
|
||||
def b64decode_or_bytes(data):
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
else:
|
||||
return b64decode(data)
|
||||
|
||||
|
||||
class BinaryBase64Field(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
def to_internal_value(self, data):
|
||||
return b64decode_or_bytes(data)
|
||||
|
||||
|
||||
class CollectionEncryptionKeyField(BinaryBase64Field):
|
||||
def get_attribute(self, instance):
|
||||
request = self.context.get("request", None)
|
||||
if request is not None:
|
||||
return instance.members.get(user=request.user).encryptionKey
|
||||
return None
|
||||
|
||||
|
||||
class CollectionTypeField(BinaryBase64Field):
|
||||
def get_attribute(self, instance):
|
||||
request = self.context.get("request", None)
|
||||
if request is not None:
|
||||
collection_type = instance.members.get(user=request.user).collectionType
|
||||
return collection_type and collection_type.uid
|
||||
return None
|
||||
|
||||
|
||||
class UserSlugRelatedField(serializers.SlugRelatedField):
|
||||
def get_queryset(self):
|
||||
view = self.context.get("view", None)
|
||||
return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs))
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
return super().to_internal_value(data.lower())
|
||||
|
||||
|
||||
class ChunksField(serializers.RelatedField):
|
||||
def to_representation(self, obj):
|
||||
obj = obj.chunk
|
||||
if self.context.get("prefetch") == "auto":
|
||||
with open(obj.chunkFile.path, "rb") as f:
|
||||
return (obj.uid, f.read())
|
||||
else:
|
||||
return (obj.uid,)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
content = data[1] if len(data) > 1 else None
|
||||
if data[0] is None:
|
||||
raise EtebaseValidationError("no_null", "null is not allowed")
|
||||
return (data[0], b64decode_or_bytes(content) if content is not None else None)
|
||||
|
||||
|
||||
class BetterErrorsMixin:
|
||||
@property
|
||||
def errors(self):
|
||||
nice = []
|
||||
errors = super().errors
|
||||
for error_type in errors:
|
||||
if error_type == "non_field_errors":
|
||||
nice.extend(self.flatten_errors(None, errors[error_type]))
|
||||
else:
|
||||
nice.extend(self.flatten_errors(error_type, errors[error_type]))
|
||||
if nice:
|
||||
return {"code": "field_errors", "detail": "Field validations failed.", "errors": nice}
|
||||
return {}
|
||||
|
||||
def flatten_errors(self, field_name, errors):
|
||||
ret = []
|
||||
if isinstance(errors, dict):
|
||||
for error_key in errors:
|
||||
error = errors[error_key]
|
||||
ret.extend(self.flatten_errors("{}.{}".format(field_name, error_key), error))
|
||||
else:
|
||||
for error in errors:
|
||||
if getattr(error, "messages", None):
|
||||
message = error.messages[0]
|
||||
else:
|
||||
message = str(error)
|
||||
ret.append(
|
||||
{
|
||||
"field": field_name,
|
||||
"code": error.code,
|
||||
"detail": message,
|
||||
}
|
||||
)
|
||||
return ret
|
||||
|
||||
def transform_validation_error(self, prefix, err):
|
||||
if hasattr(err, "error_dict"):
|
||||
errors = self.flatten_errors(prefix, err.error_dict)
|
||||
elif not hasattr(err, "message"):
|
||||
errors = self.flatten_errors(prefix, err.error_list)
|
||||
else:
|
||||
raise EtebaseValidationError(err.code, err.message)
|
||||
|
||||
raise serializers.ValidationError(
|
||||
{
|
||||
"code": "field_errors",
|
||||
"detail": "Field validations failed.",
|
||||
"errors": errors,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class CollectionItemChunkSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = models.CollectionItemChunk
|
||||
fields = ("uid", "chunkFile")
|
||||
|
||||
|
||||
class CollectionItemRevisionSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
chunks = ChunksField(
|
||||
source="chunks_relation",
|
||||
queryset=models.RevisionChunkRelation.objects.all(),
|
||||
style={"base_template": "input.html"},
|
||||
many=True,
|
||||
)
|
||||
meta = BinaryBase64Field()
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionItemRevision
|
||||
fields = ("chunks", "meta", "uid", "deleted")
|
||||
extra_kwargs = {
|
||||
"uid": {"validators": []}, # We deal with it in the serializers
|
||||
}
|
||||
|
||||
|
||||
class CollectionItemSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
encryptionKey = BinaryBase64Field(required=False, default=None, allow_null=True)
|
||||
etag = serializers.CharField(allow_null=True, write_only=True)
|
||||
content = CollectionItemRevisionSerializer(many=False)
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionItem
|
||||
fields = ("uid", "version", "encryptionKey", "content", "etag")
|
||||
|
||||
def create(self, validated_data):
|
||||
"""Function that's called when this serializer creates an item"""
|
||||
validate_etag = self.context.get("validate_etag", False)
|
||||
etag = validated_data.pop("etag")
|
||||
revision_data = validated_data.pop("content")
|
||||
uid = validated_data.pop("uid")
|
||||
|
||||
Model = self.__class__.Meta.model
|
||||
|
||||
with transaction.atomic():
|
||||
instance, created = Model.objects.get_or_create(uid=uid, defaults=validated_data)
|
||||
cur_etag = instance.etag if not created else None
|
||||
|
||||
# If we are trying to update an up to date item, abort early and consider it a success
|
||||
if cur_etag == revision_data.get("uid"):
|
||||
return instance
|
||||
|
||||
if validate_etag and cur_etag != etag:
|
||||
raise EtebaseValidationError(
|
||||
"wrong_etag",
|
||||
"Wrong etag. Expected {} got {}".format(cur_etag, etag),
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
if not created:
|
||||
# We don't have to use select_for_update here because the unique constraint on current guards against
|
||||
# the race condition. But it's a good idea because it'll lock and wait rather than fail.
|
||||
current_revision = instance.revisions.filter(current=True).select_for_update().first()
|
||||
|
||||
# If we are just re-uploading the same revision, consider it a succes and return.
|
||||
if current_revision.uid == revision_data.get("uid"):
|
||||
return instance
|
||||
|
||||
current_revision.current = None
|
||||
current_revision.save()
|
||||
|
||||
try:
|
||||
process_revisions_for_item(instance, revision_data)
|
||||
except django_exceptions.ValidationError as e:
|
||||
self.transform_validation_error("content", e)
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
# We never update, we always update in the create method
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CollectionItemDepSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
etag = serializers.CharField()
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionItem
|
||||
fields = ("uid", "etag")
|
||||
|
||||
def validate(self, data):
|
||||
item = self.__class__.Meta.model.objects.get(uid=data["uid"])
|
||||
etag = data["etag"]
|
||||
if item.etag != etag:
|
||||
raise EtebaseValidationError(
|
||||
"wrong_etag",
|
||||
"Wrong etag. Expected {} got {}".format(item.etag, etag),
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class CollectionItemBulkGetSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
etag = serializers.CharField(required=False)
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionItem
|
||||
fields = ("uid", "etag")
|
||||
|
||||
|
||||
class CollectionListMultiSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
collectionTypes = serializers.ListField(child=BinaryBase64Field())
|
||||
|
||||
|
||||
class CollectionSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
collectionKey = CollectionEncryptionKeyField()
|
||||
collectionType = CollectionTypeField()
|
||||
accessLevel = serializers.SerializerMethodField("get_access_level_from_context")
|
||||
stoken = serializers.CharField(read_only=True)
|
||||
|
||||
item = CollectionItemSerializer(many=False, source="main_item")
|
||||
|
||||
class Meta:
|
||||
model = models.Collection
|
||||
fields = ("item", "accessLevel", "collectionKey", "collectionType", "stoken")
|
||||
|
||||
def get_access_level_from_context(self, obj):
|
||||
request = self.context.get("request", None)
|
||||
if request is not None:
|
||||
return obj.members.get(user=request.user).accessLevel
|
||||
return None
|
||||
|
||||
def create(self, validated_data):
|
||||
"""Function that's called when this serializer creates an item"""
|
||||
collection_key = validated_data.pop("collectionKey")
|
||||
collection_type = validated_data.pop("collectionType")
|
||||
|
||||
user = validated_data.get("owner")
|
||||
main_item_data = validated_data.pop("main_item")
|
||||
uid = main_item_data.get("uid")
|
||||
etag = main_item_data.pop("etag")
|
||||
revision_data = main_item_data.pop("content")
|
||||
|
||||
instance = self.__class__.Meta.model(uid=uid, **validated_data)
|
||||
|
||||
with transaction.atomic():
|
||||
if etag is not None:
|
||||
raise EtebaseValidationError("bad_etag", "etag is not null")
|
||||
|
||||
try:
|
||||
instance.validate_unique()
|
||||
except django_exceptions.ValidationError:
|
||||
raise EtebaseValidationError(
|
||||
"unique_uid", "Collection with this uid already exists", status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
instance.save()
|
||||
|
||||
main_item = models.CollectionItem.objects.create(**main_item_data, collection=instance)
|
||||
|
||||
instance.main_item = main_item
|
||||
instance.save()
|
||||
|
||||
process_revisions_for_item(main_item, revision_data)
|
||||
|
||||
collection_type_obj, _ = models.CollectionType.objects.get_or_create(uid=collection_type, owner=user)
|
||||
|
||||
models.CollectionMember(
|
||||
collection=instance,
|
||||
stoken=models.Stoken.objects.create(),
|
||||
user=user,
|
||||
accessLevel=models.AccessLevels.ADMIN,
|
||||
encryptionKey=collection_key,
|
||||
collectionType=collection_type_obj,
|
||||
).save()
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CollectionMemberSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
username = UserSlugRelatedField(
|
||||
source="user",
|
||||
read_only=True,
|
||||
style={"base_template": "input.html"},
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionMember
|
||||
fields = ("username", "accessLevel")
|
||||
|
||||
def create(self, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
with transaction.atomic():
|
||||
# We only allow updating accessLevel
|
||||
access_level = validated_data.pop("accessLevel")
|
||||
if instance.accessLevel != access_level:
|
||||
instance.stoken = models.Stoken.objects.create()
|
||||
instance.accessLevel = access_level
|
||||
instance.save()
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
class CollectionInvitationSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
username = UserSlugRelatedField(
|
||||
source="user",
|
||||
queryset=User.objects,
|
||||
style={"base_template": "input.html"},
|
||||
)
|
||||
collection = serializers.CharField(source="collection.uid")
|
||||
fromUsername = serializers.CharField(source="fromMember.user.username", read_only=True)
|
||||
fromPubkey = BinaryBase64Field(source="fromMember.user.userinfo.pubkey", read_only=True)
|
||||
signedEncryptionKey = BinaryBase64Field()
|
||||
|
||||
class Meta:
|
||||
model = models.CollectionInvitation
|
||||
fields = (
|
||||
"username",
|
||||
"uid",
|
||||
"collection",
|
||||
"signedEncryptionKey",
|
||||
"accessLevel",
|
||||
"fromUsername",
|
||||
"fromPubkey",
|
||||
"version",
|
||||
)
|
||||
|
||||
def validate_user(self, value):
|
||||
request = self.context["request"]
|
||||
|
||||
if request.user.username == value.lower():
|
||||
raise EtebaseValidationError("no_self_invite", "Inviting yourself is not allowed")
|
||||
return value
|
||||
|
||||
def create(self, validated_data):
|
||||
request = self.context["request"]
|
||||
collection = validated_data.pop("collection")
|
||||
|
||||
member = collection.members.get(user=request.user)
|
||||
|
||||
with transaction.atomic():
|
||||
try:
|
||||
return type(self).Meta.model.objects.create(**validated_data, fromMember=member)
|
||||
except IntegrityError:
|
||||
raise EtebaseValidationError("invitation_exists", "Invitation already exists")
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
with transaction.atomic():
|
||||
instance.accessLevel = validated_data.pop("accessLevel")
|
||||
instance.signedEncryptionKey = validated_data.pop("signedEncryptionKey")
|
||||
instance.save()
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
class InvitationAcceptSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
collectionType = BinaryBase64Field()
|
||||
encryptionKey = BinaryBase64Field()
|
||||
|
||||
def create(self, validated_data):
|
||||
|
||||
with transaction.atomic():
|
||||
invitation = self.context["invitation"]
|
||||
encryption_key = validated_data.get("encryptionKey")
|
||||
collection_type = validated_data.pop("collectionType")
|
||||
|
||||
user = invitation.user
|
||||
collection_type_obj, _ = models.CollectionType.objects.get_or_create(uid=collection_type, owner=user)
|
||||
|
||||
member = models.CollectionMember.objects.create(
|
||||
collection=invitation.collection,
|
||||
stoken=models.Stoken.objects.create(),
|
||||
user=user,
|
||||
accessLevel=invitation.accessLevel,
|
||||
encryptionKey=encryption_key,
|
||||
collectionType=collection_type_obj,
|
||||
)
|
||||
|
||||
models.CollectionMemberRemoved.objects.filter(
|
||||
user=invitation.user, collection=invitation.collection
|
||||
).delete()
|
||||
|
||||
invitation.delete()
|
||||
|
||||
return member
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class UserSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
pubkey = BinaryBase64Field(source="userinfo.pubkey")
|
||||
encryptedContent = BinaryBase64Field(source="userinfo.encryptedContent")
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = (User.USERNAME_FIELD, User.EMAIL_FIELD, "pubkey", "encryptedContent")
|
||||
|
||||
|
||||
class UserInfoPubkeySerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
pubkey = BinaryBase64Field()
|
||||
|
||||
class Meta:
|
||||
model = models.UserInfo
|
||||
fields = ("pubkey",)
|
||||
|
||||
|
||||
class UserSignupSerializer(BetterErrorsMixin, serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = User
|
||||
fields = (User.USERNAME_FIELD, User.EMAIL_FIELD)
|
||||
extra_kwargs = {
|
||||
"username": {"validators": []}, # We specifically validate in SignupSerializer
|
||||
}
|
||||
|
||||
|
||||
class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
"""Used both for creating new accounts and setting up existing ones for the first time.
|
||||
When setting up existing ones the email is ignored."
|
||||
"""
|
||||
|
||||
user = UserSignupSerializer(many=False)
|
||||
salt = BinaryBase64Field()
|
||||
loginPubkey = BinaryBase64Field()
|
||||
pubkey = BinaryBase64Field()
|
||||
encryptedContent = BinaryBase64Field()
|
||||
|
||||
def create(self, validated_data):
|
||||
"""Function that's called when this serializer creates an item"""
|
||||
user_data = validated_data.pop("user")
|
||||
|
||||
with transaction.atomic():
|
||||
view = self.context.get("view", None)
|
||||
try:
|
||||
user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs))
|
||||
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
|
||||
except User.DoesNotExist:
|
||||
# Create the user and save the casing the user chose as the first name
|
||||
try:
|
||||
instance = create_user(
|
||||
**user_data,
|
||||
password=None,
|
||||
first_name=user_data["username"],
|
||||
context=CallbackContext(view.kwargs)
|
||||
)
|
||||
instance.full_clean()
|
||||
except EtebaseValidationError as e:
|
||||
raise e
|
||||
except django_exceptions.ValidationError as e:
|
||||
self.transform_validation_error("user", e)
|
||||
except Exception as e:
|
||||
raise EtebaseValidationError("generic", str(e))
|
||||
|
||||
if hasattr(instance, "userinfo"):
|
||||
raise EtebaseValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
models.UserInfo.objects.create(**validated_data, owner=instance)
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AuthenticationLoginChallengeSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
username = serializers.CharField(required=True)
|
||||
|
||||
def create(self, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AuthenticationLoginSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
response = BinaryBase64Field()
|
||||
signature = BinaryBase64Field()
|
||||
|
||||
def create(self, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AuthenticationLoginInnerSerializer(AuthenticationLoginChallengeSerializer):
|
||||
challenge = BinaryBase64Field()
|
||||
host = serializers.CharField()
|
||||
action = serializers.CharField()
|
||||
|
||||
def create(self, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AuthenticationChangePasswordInnerSerializer(AuthenticationLoginInnerSerializer):
|
||||
loginPubkey = BinaryBase64Field()
|
||||
encryptedContent = BinaryBase64Field()
|
||||
|
||||
class Meta:
|
||||
model = models.UserInfo
|
||||
fields = ("loginPubkey", "encryptedContent")
|
||||
|
||||
def create(self, validated_data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
with transaction.atomic():
|
||||
instance.loginPubkey = validated_data.pop("loginPubkey")
|
||||
instance.encryptedContent = validated_data.pop("encryptedContent")
|
||||
instance.save()
|
||||
|
||||
return instance
|
@ -1,3 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
@ -1,46 +0,0 @@
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import exceptions
|
||||
from rest_framework.authentication import TokenAuthentication as DRFTokenAuthentication
|
||||
|
||||
from .models import AuthToken, get_default_expiry
|
||||
|
||||
|
||||
AUTO_REFRESH = True
|
||||
MIN_REFRESH_INTERVAL = 60
|
||||
|
||||
|
||||
class TokenAuthentication(DRFTokenAuthentication):
|
||||
keyword = "Token"
|
||||
model = AuthToken
|
||||
|
||||
def authenticate_credentials(self, key):
|
||||
msg = _("Invalid token.")
|
||||
model = self.get_model()
|
||||
try:
|
||||
token = model.objects.select_related("user").get(key=key)
|
||||
except model.DoesNotExist:
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
if not token.user.is_active:
|
||||
raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
|
||||
|
||||
if token.expiry is not None:
|
||||
if token.expiry < timezone.now():
|
||||
token.delete()
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
if AUTO_REFRESH:
|
||||
self.renew_token(token)
|
||||
|
||||
return (token.user, token)
|
||||
|
||||
def renew_token(self, auth_token):
|
||||
current_expiry = auth_token.expiry
|
||||
new_expiry = get_default_expiry()
|
||||
# Throttle refreshing of token to avoid db writes
|
||||
delta = (new_expiry - current_expiry).total_seconds()
|
||||
if delta > MIN_REFRESH_INTERVAL:
|
||||
auth_token.expiry = new_expiry
|
||||
auth_token.save(update_fields=("expiry",))
|
@ -1,9 +1,9 @@
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
from django.utils.crypto import get_random_string
|
||||
from myauth.models import get_typed_user_model
|
||||
|
||||
User = get_user_model()
|
||||
User = get_typed_user_model()
|
||||
|
||||
|
||||
def generate_key():
|
||||
|
@ -1,30 +0,0 @@
|
||||
from django.conf import settings
|
||||
from django.conf.urls import include
|
||||
from django.urls import path
|
||||
|
||||
from rest_framework_nested import routers
|
||||
|
||||
from django_etebase import views
|
||||
|
||||
router = routers.DefaultRouter()
|
||||
router.register(r"collection", views.CollectionViewSet)
|
||||
router.register(r"authentication", views.AuthenticationViewSet, basename="authentication")
|
||||
router.register(r"invitation/incoming", views.InvitationIncomingViewSet, basename="invitation_incoming")
|
||||
router.register(r"invitation/outgoing", views.InvitationOutgoingViewSet, basename="invitation_outgoing")
|
||||
|
||||
collections_router = routers.NestedSimpleRouter(router, r"collection", lookup="collection")
|
||||
collections_router.register(r"item", views.CollectionItemViewSet, basename="collection_item")
|
||||
collections_router.register(r"member", views.CollectionMemberViewSet, basename="collection_member")
|
||||
|
||||
item_router = routers.NestedSimpleRouter(collections_router, r"item", lookup="collection_item")
|
||||
item_router.register(r"chunk", views.CollectionItemChunkViewSet, basename="collection_items_chunk")
|
||||
|
||||
if settings.DEBUG:
|
||||
router.register(r"test/authentication", views.TestAuthenticationViewSet, basename="test_authentication")
|
||||
|
||||
app_name = "django_etebase"
|
||||
urlpatterns = [
|
||||
path("v1/", include(router.urls)),
|
||||
path("v1/", include(collections_router.urls)),
|
||||
path("v1/", include(item_router.urls)),
|
||||
]
|
@ -1,13 +1,13 @@
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
|
||||
from . import app_settings
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
User = get_typed_user_model()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -15,6 +15,7 @@ class CallbackContext:
|
||||
"""Class for passing extra context to callbacks"""
|
||||
|
||||
url_kwargs: t.Dict[str, t.Any]
|
||||
user: t.Optional[UserType] = None
|
||||
|
||||
|
||||
def get_user_queryset(queryset, context: CallbackContext):
|
||||
@ -27,7 +28,7 @@ def get_user_queryset(queryset, context: CallbackContext):
|
||||
def create_user(context: CallbackContext, *args, **kwargs):
|
||||
custom_func = app_settings.CREATE_USER_FUNC
|
||||
if custom_func is not None:
|
||||
return custom_func(*args, **kwargs)
|
||||
return custom_func(context, *args, **kwargs)
|
||||
return User.objects.create_user(*args, **kwargs)
|
||||
|
||||
|
||||
|
@ -1,861 +0,0 @@
|
||||
# Copyright © 2017 Tom Hacohen
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, version 3.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import msgpack
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model, user_logged_in, user_logged_out
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db import transaction, IntegrityError
|
||||
from django.db.models import Q
|
||||
from django.http import HttpResponseBadRequest, HttpResponse, Http404
|
||||
from django.shortcuts import get_object_or_404
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework import viewsets
|
||||
from rest_framework.decorators import action as action_decorator
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.parsers import JSONParser, FormParser, MultiPartParser
|
||||
from rest_framework.renderers import BrowsableAPIRenderer
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
|
||||
import nacl.encoding
|
||||
import nacl.signing
|
||||
import nacl.secret
|
||||
import nacl.hash
|
||||
|
||||
from .sendfile import sendfile
|
||||
from .token_auth.models import AuthToken
|
||||
|
||||
from .drf_msgpack.parsers import MessagePackParser
|
||||
from .drf_msgpack.renderers import MessagePackRenderer
|
||||
|
||||
from . import app_settings, permissions
|
||||
from .renderers import JSONRenderer
|
||||
from .models import (
|
||||
Collection,
|
||||
CollectionItem,
|
||||
CollectionItemRevision,
|
||||
CollectionMember,
|
||||
CollectionMemberRemoved,
|
||||
CollectionInvitation,
|
||||
Stoken,
|
||||
UserInfo,
|
||||
)
|
||||
from .serializers import (
|
||||
AuthenticationChangePasswordInnerSerializer,
|
||||
AuthenticationSignupSerializer,
|
||||
AuthenticationLoginChallengeSerializer,
|
||||
AuthenticationLoginSerializer,
|
||||
AuthenticationLoginInnerSerializer,
|
||||
CollectionSerializer,
|
||||
CollectionItemSerializer,
|
||||
CollectionItemBulkGetSerializer,
|
||||
CollectionItemDepSerializer,
|
||||
CollectionItemRevisionSerializer,
|
||||
CollectionItemChunkSerializer,
|
||||
CollectionListMultiSerializer,
|
||||
CollectionMemberSerializer,
|
||||
CollectionInvitationSerializer,
|
||||
InvitationAcceptSerializer,
|
||||
UserInfoPubkeySerializer,
|
||||
UserSerializer,
|
||||
)
|
||||
from .utils import get_user_queryset, CallbackContext
|
||||
from .exceptions import EtebaseValidationError
|
||||
from .parsers import ChunkUploadParser
|
||||
from .signals import user_signed_up
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def msgpack_encode(content):
|
||||
return msgpack.packb(content, use_bin_type=True)
|
||||
|
||||
|
||||
def msgpack_decode(content):
|
||||
return msgpack.unpackb(content, raw=False)
|
||||
|
||||
|
||||
class BaseViewSet(viewsets.ModelViewSet):
|
||||
authentication_classes = tuple(app_settings.API_AUTHENTICATORS)
|
||||
permission_classes = tuple(app_settings.API_PERMISSIONS)
|
||||
renderer_classes = [JSONRenderer, MessagePackRenderer] + ([BrowsableAPIRenderer] if settings.DEBUG else [])
|
||||
parser_classes = [JSONParser, MessagePackParser, FormParser, MultiPartParser]
|
||||
stoken_annotation = None
|
||||
|
||||
def get_serializer_class(self):
|
||||
serializer_class = self.serializer_class
|
||||
|
||||
if self.request.method == "PUT":
|
||||
serializer_class = getattr(self, "serializer_update_class", serializer_class)
|
||||
|
||||
return serializer_class
|
||||
|
||||
def get_collection_queryset(self, queryset=Collection.objects):
|
||||
user = self.request.user
|
||||
return queryset.filter(members__user=user)
|
||||
|
||||
def get_stoken_obj_id(self, request):
|
||||
return request.GET.get("stoken", None)
|
||||
|
||||
def get_stoken_obj(self, request):
|
||||
stoken = self.get_stoken_obj_id(request)
|
||||
|
||||
if stoken is not None:
|
||||
try:
|
||||
return Stoken.objects.get(uid=stoken)
|
||||
except Stoken.DoesNotExist:
|
||||
raise EtebaseValidationError("bad_stoken", "Invalid stoken.", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
def filter_by_stoken(self, request, queryset):
|
||||
stoken_rev = self.get_stoken_obj(request)
|
||||
|
||||
queryset = queryset.annotate(max_stoken=self.stoken_annotation).order_by("max_stoken")
|
||||
|
||||
if stoken_rev is not None:
|
||||
queryset = queryset.filter(max_stoken__gt=stoken_rev.id)
|
||||
|
||||
return queryset, stoken_rev
|
||||
|
||||
def get_queryset_stoken(self, queryset):
|
||||
maxid = -1
|
||||
for row in queryset:
|
||||
rowmaxid = getattr(row, "max_stoken") or -1
|
||||
maxid = max(maxid, rowmaxid)
|
||||
new_stoken = (maxid >= 0) and Stoken.objects.get(id=maxid)
|
||||
|
||||
return new_stoken or None
|
||||
|
||||
def filter_by_stoken_and_limit(self, request, queryset):
|
||||
limit = int(request.GET.get("limit", 50))
|
||||
|
||||
queryset, stoken_rev = self.filter_by_stoken(request, queryset)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
new_stoken_obj = self.get_queryset_stoken(result) or stoken_rev
|
||||
|
||||
return result, new_stoken_obj, done
|
||||
|
||||
# Change how our list works by default
|
||||
def list(self, request, collection_uid=None, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"done": True, # we always return all the items, so it's always done
|
||||
}
|
||||
|
||||
return Response(ret)
|
||||
|
||||
|
||||
class CollectionViewSet(BaseViewSet):
|
||||
allowed_methods = ["GET", "POST"]
|
||||
permission_classes = BaseViewSet.permission_classes + (permissions.IsCollectionAdminOrReadOnly,)
|
||||
queryset = Collection.objects.all()
|
||||
serializer_class = CollectionSerializer
|
||||
lookup_field = "uid"
|
||||
lookup_url_kwarg = "uid"
|
||||
stoken_annotation = Collection.stoken_annotation
|
||||
|
||||
def get_queryset(self, queryset=None):
|
||||
if queryset is None:
|
||||
queryset = type(self).queryset
|
||||
return self.get_collection_queryset(queryset)
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
prefetch = self.request.query_params.get("prefetch", "auto")
|
||||
context.update({"request": self.request, "prefetch": prefetch})
|
||||
return context
|
||||
|
||||
def destroy(self, request, uid=None, *args, **kwargs):
|
||||
# FIXME: implement
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def partial_update(self, request, uid=None, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
return Response({}, status=status.HTTP_201_CREATED)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
return self.list_common(request, queryset, *args, **kwargs)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def list_multi(self, request, *args, **kwargs):
|
||||
serializer = CollectionListMultiSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
collection_types = serializer.validated_data["collectionTypes"]
|
||||
|
||||
queryset = self.get_queryset()
|
||||
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
|
||||
queryset = queryset.filter(
|
||||
Q(members__collectionType__uid__in=collection_types) | Q(members__collectionType__isnull=True)
|
||||
)
|
||||
|
||||
return self.list_common(request, queryset, *args, **kwargs)
|
||||
|
||||
def list_common(self, request, queryset, *args, **kwargs):
|
||||
result, new_stoken_obj, done = self.filter_by_stoken_and_limit(request, queryset)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
|
||||
serializer = self.get_serializer(result, many=True)
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"stoken": new_stoken,
|
||||
"done": done,
|
||||
}
|
||||
|
||||
stoken_obj = self.get_stoken_obj(request)
|
||||
if stoken_obj is not None:
|
||||
# FIXME: honour limit? (the limit should be combined for data and this because of stoken)
|
||||
remed_qs = CollectionMemberRemoved.objects.filter(user=request.user, stoken__id__gt=stoken_obj.id)
|
||||
if not ret["done"]:
|
||||
# We only filter by the new_stoken if we are not done. This is because if we are done, the new stoken
|
||||
# can point to the most recent collection change rather than most recent removed membership.
|
||||
remed_qs = remed_qs.filter(stoken__id__lte=new_stoken_obj.id)
|
||||
|
||||
remed = remed_qs.values_list("collection__uid", flat=True)
|
||||
if len(remed) > 0:
|
||||
ret["removedMemberships"] = [{"uid": x} for x in remed]
|
||||
|
||||
return Response(ret)
|
||||
|
||||
|
||||
class CollectionItemViewSet(BaseViewSet):
|
||||
allowed_methods = ["GET", "POST", "PUT"]
|
||||
permission_classes = BaseViewSet.permission_classes + (permissions.HasWriteAccessOrReadOnly,)
|
||||
queryset = CollectionItem.objects.all()
|
||||
serializer_class = CollectionItemSerializer
|
||||
lookup_field = "uid"
|
||||
stoken_annotation = CollectionItem.stoken_annotation
|
||||
|
||||
def get_queryset(self):
|
||||
collection_uid = self.kwargs["collection_uid"]
|
||||
try:
|
||||
collection = self.get_collection_queryset(Collection.objects).get(uid=collection_uid)
|
||||
except Collection.DoesNotExist:
|
||||
raise Http404("Collection does not exist")
|
||||
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
|
||||
queryset = type(self).queryset.filter(collection__pk=collection.pk, revisions__current=True)
|
||||
|
||||
return queryset
|
||||
|
||||
def get_serializer_context(self):
|
||||
context = super().get_serializer_context()
|
||||
prefetch = self.request.query_params.get("prefetch", "auto")
|
||||
context.update({"request": self.request, "prefetch": prefetch})
|
||||
return context
|
||||
|
||||
def create(self, request, collection_uid=None, *args, **kwargs):
|
||||
# We create using batch and transaction
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def destroy(self, request, collection_uid=None, uid=None, *args, **kwargs):
|
||||
# We can't have destroy because we need to get data from the user (in the body) such as hmac.
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def update(self, request, collection_uid=None, uid=None, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def partial_update(self, request, collection_uid=None, uid=None, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
def list(self, request, collection_uid=None, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
if not self.request.query_params.get("withCollection", False):
|
||||
queryset = queryset.filter(parent__isnull=True)
|
||||
|
||||
result, new_stoken_obj, done = self.filter_by_stoken_and_limit(request, queryset)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
|
||||
serializer = self.get_serializer(result, many=True)
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"stoken": new_stoken,
|
||||
"done": done,
|
||||
}
|
||||
return Response(ret)
|
||||
|
||||
@action_decorator(detail=True, methods=["GET"])
|
||||
def revision(self, request, collection_uid=None, uid=None, *args, **kwargs):
|
||||
col = get_object_or_404(self.get_collection_queryset(Collection.objects), uid=collection_uid)
|
||||
item = get_object_or_404(col.items, uid=uid)
|
||||
|
||||
limit = int(request.GET.get("limit", 50))
|
||||
iterator = request.GET.get("iterator", None)
|
||||
|
||||
queryset = item.revisions.order_by("-id")
|
||||
|
||||
if iterator is not None:
|
||||
iterator = get_object_or_404(queryset, uid=iterator)
|
||||
queryset = queryset.filter(id__lt=iterator.id)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
serializer = CollectionItemRevisionSerializer(result, context=self.get_serializer_context(), many=True)
|
||||
|
||||
iterator = serializer.data[-1]["uid"] if len(result) > 0 else None
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"iterator": iterator,
|
||||
"done": done,
|
||||
}
|
||||
return Response(ret)
|
||||
|
||||
# FIXME: rename to something consistent with what the clients have - maybe list_updates?
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def fetch_updates(self, request, collection_uid=None, *args, **kwargs):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
serializer = CollectionItemBulkGetSerializer(data=request.data, many=True)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
# FIXME: make configurable?
|
||||
item_limit = 200
|
||||
|
||||
if len(serializer.validated_data) > item_limit:
|
||||
content = {"code": "too_many_items", "detail": "Request has too many items. Limit: {}".format(item_limit)}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
queryset, stoken_rev = self.filter_by_stoken(request, queryset)
|
||||
|
||||
uids, etags = zip(*[(item["uid"], item.get("etag")) for item in serializer.validated_data])
|
||||
revs = CollectionItemRevision.objects.filter(uid__in=etags, current=True)
|
||||
queryset = queryset.filter(uid__in=uids).exclude(revisions__in=revs)
|
||||
|
||||
new_stoken_obj = self.get_queryset_stoken(queryset)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
stoken = stoken_rev and getattr(stoken_rev, "uid", None)
|
||||
new_stoken = new_stoken or stoken
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"stoken": new_stoken,
|
||||
"done": True, # we always return all the items, so it's always done
|
||||
}
|
||||
return Response(ret)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def batch(self, request, collection_uid=None, *args, **kwargs):
|
||||
return self.transaction(request, collection_uid, validate_etag=False)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def transaction(self, request, collection_uid=None, validate_etag=True, *args, **kwargs):
|
||||
stoken = request.GET.get("stoken", None)
|
||||
with transaction.atomic(): # We need this for locking on the collection object
|
||||
collection_object = get_object_or_404(
|
||||
self.get_collection_queryset(Collection.objects).select_for_update(), # Lock writes on the collection
|
||||
uid=collection_uid,
|
||||
)
|
||||
|
||||
if stoken is not None and stoken != collection_object.stoken:
|
||||
content = {"code": "stale_stoken", "detail": "Stoken is too old"}
|
||||
return Response(content, status=status.HTTP_409_CONFLICT)
|
||||
|
||||
items = request.data.get("items")
|
||||
deps = request.data.get("deps", None)
|
||||
# FIXME: It should just be one serializer
|
||||
context = self.get_serializer_context()
|
||||
context.update({"validate_etag": validate_etag})
|
||||
serializer = self.get_serializer_class()(data=items, context=context, many=True)
|
||||
deps_serializer = CollectionItemDepSerializer(data=deps, context=context, many=True)
|
||||
|
||||
ser_valid = serializer.is_valid()
|
||||
deps_ser_valid = deps is None or deps_serializer.is_valid()
|
||||
if ser_valid and deps_ser_valid:
|
||||
items = serializer.save(collection=collection_object)
|
||||
|
||||
ret = {}
|
||||
return Response(ret, status=status.HTTP_200_OK)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"items": serializer.errors,
|
||||
"deps": deps_serializer.errors if deps is not None else [],
|
||||
},
|
||||
status=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
|
||||
class CollectionItemChunkViewSet(viewsets.ViewSet):
|
||||
allowed_methods = ["GET", "PUT"]
|
||||
authentication_classes = BaseViewSet.authentication_classes
|
||||
permission_classes = BaseViewSet.permission_classes
|
||||
renderer_classes = BaseViewSet.renderer_classes
|
||||
parser_classes = (ChunkUploadParser,)
|
||||
serializer_class = CollectionItemChunkSerializer
|
||||
lookup_field = "uid"
|
||||
|
||||
def get_serializer_class(self):
|
||||
return self.serializer_class
|
||||
|
||||
def get_collection_queryset(self, queryset=Collection.objects):
|
||||
user = self.request.user
|
||||
return queryset.filter(members__user=user)
|
||||
|
||||
def update(self, request, *args, collection_uid=None, collection_item_uid=None, uid=None, **kwargs):
|
||||
col = get_object_or_404(self.get_collection_queryset(), uid=collection_uid)
|
||||
# IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid)
|
||||
|
||||
data = {
|
||||
"uid": uid,
|
||||
"chunkFile": request.data["file"],
|
||||
}
|
||||
|
||||
serializer = self.get_serializer_class()(data=data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
try:
|
||||
serializer.save(collection=col)
|
||||
except IntegrityError:
|
||||
return Response(
|
||||
{"code": "chunk_exists", "detail": "Chunk already exists."}, status=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
return Response({}, status=status.HTTP_201_CREATED)
|
||||
|
||||
@action_decorator(detail=True, methods=["GET"])
|
||||
def download(self, request, collection_uid=None, collection_item_uid=None, uid=None, *args, **kwargs):
|
||||
col = get_object_or_404(self.get_collection_queryset(), uid=collection_uid)
|
||||
chunk = get_object_or_404(col.chunks, uid=uid)
|
||||
|
||||
filename = chunk.chunkFile.path
|
||||
return sendfile(request, filename)
|
||||
|
||||
|
||||
class CollectionMemberViewSet(BaseViewSet):
|
||||
allowed_methods = ["GET", "PUT", "DELETE"]
|
||||
our_base_permission_classes = BaseViewSet.permission_classes
|
||||
permission_classes = our_base_permission_classes + (permissions.IsCollectionAdmin,)
|
||||
queryset = CollectionMember.objects.all()
|
||||
serializer_class = CollectionMemberSerializer
|
||||
lookup_field = f"user__{User.USERNAME_FIELD}__iexact"
|
||||
lookup_url_kwarg = "username"
|
||||
stoken_annotation = CollectionMember.stoken_annotation
|
||||
|
||||
# FIXME: need to make sure that there's always an admin, and maybe also don't let an owner remove adm access
|
||||
# (if we want to transfer, we need to do that specifically)
|
||||
|
||||
def get_queryset(self, queryset=None):
|
||||
collection_uid = self.kwargs["collection_uid"]
|
||||
try:
|
||||
collection = self.get_collection_queryset(Collection.objects).get(uid=collection_uid)
|
||||
except Collection.DoesNotExist:
|
||||
raise Http404("Collection does not exist")
|
||||
|
||||
if queryset is None:
|
||||
queryset = type(self).queryset
|
||||
|
||||
return queryset.filter(collection=collection)
|
||||
|
||||
# We override this method because we expect the stoken to be called iterator
|
||||
def get_stoken_obj_id(self, request):
|
||||
return request.GET.get("iterator", None)
|
||||
|
||||
def list(self, request, collection_uid=None, *args, **kwargs):
|
||||
queryset = self.get_queryset().order_by("id")
|
||||
result, new_stoken_obj, done = self.filter_by_stoken_and_limit(request, queryset)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
serializer = self.get_serializer(result, many=True)
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"iterator": new_stoken, # Here we call it an iterator, it's only stoken for collection/items
|
||||
"done": done,
|
||||
}
|
||||
|
||||
return Response(ret)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
# FIXME: block leaving if we are the last admins - should be deleted / assigned in this case depending if there
|
||||
# are other memebers.
|
||||
def perform_destroy(self, instance):
|
||||
instance.revoke()
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"], permission_classes=our_base_permission_classes)
|
||||
def leave(self, request, collection_uid=None, *args, **kwargs):
|
||||
collection_uid = self.kwargs["collection_uid"]
|
||||
col = get_object_or_404(self.get_collection_queryset(Collection.objects), uid=collection_uid)
|
||||
|
||||
member = col.members.get(user=request.user)
|
||||
self.perform_destroy(member)
|
||||
|
||||
return Response({})
|
||||
|
||||
|
||||
class InvitationBaseViewSet(BaseViewSet):
|
||||
queryset = CollectionInvitation.objects.all()
|
||||
serializer_class = CollectionInvitationSerializer
|
||||
lookup_field = "uid"
|
||||
lookup_url_kwarg = "invitation_uid"
|
||||
|
||||
def list(self, request, collection_uid=None, *args, **kwargs):
|
||||
limit = int(request.GET.get("limit", 50))
|
||||
iterator = request.GET.get("iterator", None)
|
||||
|
||||
queryset = self.get_queryset().order_by("id")
|
||||
|
||||
if iterator is not None:
|
||||
iterator = get_object_or_404(queryset, uid=iterator)
|
||||
queryset = queryset.filter(id__gt=iterator.id)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
serializer = self.get_serializer(result, many=True)
|
||||
|
||||
iterator = serializer.data[-1]["uid"] if len(result) > 0 else None
|
||||
|
||||
ret = {
|
||||
"data": serializer.data,
|
||||
"iterator": iterator,
|
||||
"done": done,
|
||||
}
|
||||
|
||||
return Response(ret)
|
||||
|
||||
|
||||
class InvitationOutgoingViewSet(InvitationBaseViewSet):
|
||||
allowed_methods = ["GET", "POST", "PUT", "DELETE"]
|
||||
|
||||
def get_queryset(self, queryset=None):
|
||||
if queryset is None:
|
||||
queryset = type(self).queryset
|
||||
|
||||
return queryset.filter(fromMember__user=self.request.user)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
collection_uid = serializer.validated_data.get("collection", {}).get("uid")
|
||||
|
||||
try:
|
||||
collection = self.get_collection_queryset(Collection.objects).get(uid=collection_uid)
|
||||
except Collection.DoesNotExist:
|
||||
raise Http404("Collection does not exist")
|
||||
|
||||
if request.user == serializer.validated_data.get("user"):
|
||||
content = {"code": "self_invite", "detail": "Inviting yourself is invalid"}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not permissions.is_collection_admin(collection, request.user):
|
||||
raise PermissionDenied(
|
||||
{"code": "admin_access_required", "detail": "User is not an admin of this collection"}
|
||||
)
|
||||
|
||||
serializer.save(collection=collection)
|
||||
|
||||
return Response({}, status=status.HTTP_201_CREATED)
|
||||
|
||||
@action_decorator(detail=False, allowed_methods=["GET"], methods=["GET"])
|
||||
def fetch_user_profile(self, request, *args, **kwargs):
|
||||
username = request.GET.get("username")
|
||||
kwargs = {User.USERNAME_FIELD: username.lower()}
|
||||
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs)
|
||||
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
|
||||
serializer = UserInfoPubkeySerializer(user_info)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class InvitationIncomingViewSet(InvitationBaseViewSet):
|
||||
allowed_methods = ["GET", "DELETE"]
|
||||
|
||||
def get_queryset(self, queryset=None):
|
||||
if queryset is None:
|
||||
queryset = type(self).queryset
|
||||
|
||||
return queryset.filter(user=self.request.user)
|
||||
|
||||
@action_decorator(detail=True, allowed_methods=["POST"], methods=["POST"])
|
||||
def accept(self, request, invitation_uid=None, *args, **kwargs):
|
||||
invitation = get_object_or_404(self.get_queryset(), uid=invitation_uid)
|
||||
context = self.get_serializer_context()
|
||||
context.update({"invitation": invitation})
|
||||
|
||||
serializer = InvitationAcceptSerializer(data=request.data, context=context)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
return Response(status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
class AuthenticationViewSet(viewsets.ViewSet):
|
||||
allowed_methods = ["POST"]
|
||||
authentication_classes = BaseViewSet.authentication_classes
|
||||
renderer_classes = BaseViewSet.renderer_classes
|
||||
parser_classes = BaseViewSet.parser_classes
|
||||
|
||||
def get_encryption_key(self, salt):
|
||||
key = nacl.hash.blake2b(settings.SECRET_KEY.encode(), encoder=nacl.encoding.RawEncoder)
|
||||
return nacl.hash.blake2b(
|
||||
b"",
|
||||
key=key,
|
||||
salt=salt[: nacl.hash.BLAKE2B_SALTBYTES],
|
||||
person=b"etebase-auth",
|
||||
encoder=nacl.encoding.RawEncoder,
|
||||
)
|
||||
|
||||
def get_queryset(self):
|
||||
return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||
|
||||
def get_serializer_context(self):
|
||||
return {"request": self.request, "format": self.format_kwarg, "view": self}
|
||||
|
||||
def login_response_data(self, user):
|
||||
return {
|
||||
"token": AuthToken.objects.create(user=user).key,
|
||||
"user": UserSerializer(user).data,
|
||||
}
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def signup(self, request, *args, **kwargs):
|
||||
serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context())
|
||||
serializer.is_valid(raise_exception=True)
|
||||
user = serializer.save()
|
||||
|
||||
user_signed_up.send(sender=user.__class__, request=request, user=user)
|
||||
|
||||
data = self.login_response_data(user)
|
||||
return Response(data, status=status.HTTP_201_CREATED)
|
||||
|
||||
def get_login_user(self, username):
|
||||
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
||||
try:
|
||||
user = self.get_queryset().get(**kwargs)
|
||||
if not hasattr(user, "userinfo"):
|
||||
raise AuthenticationFailed({"code": "user_not_init", "detail": "User not properly init"})
|
||||
return user
|
||||
except User.DoesNotExist:
|
||||
raise AuthenticationFailed({"code": "user_not_found", "detail": "User not found"})
|
||||
|
||||
def validate_login_request(self, request, validated_data, response_raw, signature, expected_action):
|
||||
from datetime import datetime
|
||||
|
||||
username = validated_data.get("username")
|
||||
user = self.get_login_user(username)
|
||||
host = validated_data["host"]
|
||||
challenge = validated_data["challenge"]
|
||||
action = validated_data["action"]
|
||||
|
||||
salt = bytes(user.userinfo.salt)
|
||||
enc_key = self.get_encryption_key(salt)
|
||||
box = nacl.secret.SecretBox(enc_key)
|
||||
|
||||
challenge_data = msgpack_decode(box.decrypt(challenge))
|
||||
now = int(datetime.now().timestamp())
|
||||
if action != expected_action:
|
||||
content = {"code": "wrong_action", "detail": 'Expected "{}" but got something else'.format(expected_action)}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
||||
content = {"code": "challenge_expired", "detail": "Login challange has expired"}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
elif challenge_data["userId"] != user.id:
|
||||
content = {"code": "wrong_user", "detail": "This challenge is for the wrong user"}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
elif not settings.DEBUG and host.split(":", 1)[0] != request.get_host().split(":", 1)[0]:
|
||||
detail = 'Found wrong host name. Got: "{}" expected: "{}"'.format(host, request.get_host())
|
||||
content = {"code": "wrong_host", "detail": detail}
|
||||
return Response(content, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
|
||||
|
||||
try:
|
||||
verify_key.verify(response_raw, signature)
|
||||
except nacl.exceptions.BadSignatureError:
|
||||
return Response(
|
||||
{"code": "login_bad_signature", "detail": "Wrong password for user."},
|
||||
status=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@action_decorator(detail=False, methods=["GET"])
|
||||
def is_etebase(self, request, *args, **kwargs):
|
||||
return Response({}, status=status.HTTP_200_OK)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def login_challenge(self, request, *args, **kwargs):
|
||||
from datetime import datetime
|
||||
|
||||
serializer = AuthenticationLoginChallengeSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
username = serializer.validated_data.get("username")
|
||||
user = self.get_login_user(username)
|
||||
|
||||
salt = bytes(user.userinfo.salt)
|
||||
enc_key = self.get_encryption_key(salt)
|
||||
box = nacl.secret.SecretBox(enc_key)
|
||||
|
||||
challenge_data = {
|
||||
"timestamp": int(datetime.now().timestamp()),
|
||||
"userId": user.id,
|
||||
}
|
||||
challenge = box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder)
|
||||
|
||||
ret = {
|
||||
"salt": salt,
|
||||
"challenge": challenge,
|
||||
"version": user.userinfo.version,
|
||||
}
|
||||
return Response(ret, status=status.HTTP_200_OK)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def login(self, request, *args, **kwargs):
|
||||
outer_serializer = AuthenticationLoginSerializer(data=request.data)
|
||||
outer_serializer.is_valid(raise_exception=True)
|
||||
|
||||
response_raw = outer_serializer.validated_data["response"]
|
||||
response = msgpack_decode(response_raw)
|
||||
signature = outer_serializer.validated_data["signature"]
|
||||
|
||||
context = {"host": request.get_host()}
|
||||
serializer = AuthenticationLoginInnerSerializer(data=response, context=context)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
bad_login_response = self.validate_login_request(
|
||||
request, serializer.validated_data, response_raw, signature, "login"
|
||||
)
|
||||
if bad_login_response is not None:
|
||||
return bad_login_response
|
||||
|
||||
username = serializer.validated_data.get("username")
|
||||
user = self.get_login_user(username)
|
||||
|
||||
data = self.login_response_data(user)
|
||||
|
||||
user_logged_in.send(sender=user.__class__, request=request, user=user)
|
||||
|
||||
return Response(data, status=status.HTTP_200_OK)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"], permission_classes=[IsAuthenticated])
|
||||
def logout(self, request, *args, **kwargs):
|
||||
request.auth.delete()
|
||||
user_logged_out.send(sender=request.user.__class__, request=request, user=request.user)
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"], permission_classes=BaseViewSet.permission_classes)
|
||||
def change_password(self, request, *args, **kwargs):
|
||||
outer_serializer = AuthenticationLoginSerializer(data=request.data)
|
||||
outer_serializer.is_valid(raise_exception=True)
|
||||
|
||||
response_raw = outer_serializer.validated_data["response"]
|
||||
response = msgpack_decode(response_raw)
|
||||
signature = outer_serializer.validated_data["signature"]
|
||||
|
||||
context = {"host": request.get_host()}
|
||||
serializer = AuthenticationChangePasswordInnerSerializer(request.user.userinfo, data=response, context=context)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
bad_login_response = self.validate_login_request(
|
||||
request, serializer.validated_data, response_raw, signature, "changePassword"
|
||||
)
|
||||
if bad_login_response is not None:
|
||||
return bad_login_response
|
||||
|
||||
serializer.save()
|
||||
|
||||
return Response({}, status=status.HTTP_200_OK)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"], permission_classes=[IsAuthenticated])
|
||||
def dashboard_url(self, request, *args, **kwargs):
|
||||
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
|
||||
if get_dashboard_url is None:
|
||||
raise EtebaseValidationError(
|
||||
"not_supported", "This server doesn't have a user dashboard.", status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
ret = {
|
||||
"url": get_dashboard_url(request, *args, **kwargs),
|
||||
}
|
||||
return Response(ret)
|
||||
|
||||
|
||||
class TestAuthenticationViewSet(viewsets.ViewSet):
|
||||
allowed_methods = ["POST"]
|
||||
renderer_classes = BaseViewSet.renderer_classes
|
||||
parser_classes = BaseViewSet.parser_classes
|
||||
|
||||
def get_serializer_context(self):
|
||||
return {"request": self.request, "format": self.format_kwarg, "view": self}
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||||
|
||||
@action_decorator(detail=False, methods=["POST"])
|
||||
def reset(self, request, *args, **kwargs):
|
||||
# Only run when in DEBUG mode! It's only used for tests
|
||||
if not settings.DEBUG:
|
||||
return HttpResponseBadRequest("Only allowed in debug mode.")
|
||||
|
||||
with transaction.atomic():
|
||||
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
|
||||
|
||||
# Only allow test users for extra safety
|
||||
if not getattr(user, User.USERNAME_FIELD).startswith("test_user"):
|
||||
return HttpResponseBadRequest("Endpoint not allowed for user.")
|
||||
|
||||
if hasattr(user, "userinfo"):
|
||||
user.userinfo.delete()
|
||||
|
||||
serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context())
|
||||
serializer.is_valid(raise_exception=True)
|
||||
serializer.save()
|
||||
|
||||
# Delete all of the journal data for this user for a clear test env
|
||||
user.collection_set.all().delete()
|
||||
user.collectionmember_set.all().delete()
|
||||
user.incoming_invitations.all().delete()
|
||||
|
||||
# FIXME: also delete chunk files!!!
|
||||
|
||||
return HttpResponse()
|
267
etebase_fastapi/authentication.py
Normal file
267
etebase_fastapi/authentication.py
Normal file
@ -0,0 +1,267 @@
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
|
||||
import nacl
|
||||
import nacl.encoding
|
||||
import nacl.hash
|
||||
import nacl.secret
|
||||
import nacl.signing
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import user_logged_out, user_logged_in
|
||||
from django.core import exceptions as django_exceptions
|
||||
from django.db import transaction
|
||||
from fastapi import APIRouter, Depends, status, Request
|
||||
|
||||
from django_etebase import app_settings, models
|
||||
from django_etebase.token_auth.models import AuthToken
|
||||
from django_etebase.models import UserInfo
|
||||
from django_etebase.signals import user_signed_up
|
||||
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
|
||||
from .msgpack import MsgpackRoute
|
||||
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
|
||||
from .dependencies import AuthData, get_auth_data, get_authenticated_user
|
||||
|
||||
User = get_typed_user_model()
|
||||
authentication_router = APIRouter(route_class=MsgpackRoute)
|
||||
|
||||
|
||||
class LoginChallengeIn(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class LoginChallengeOut(BaseModel):
|
||||
salt: bytes
|
||||
challenge: bytes
|
||||
version: int
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
username: str
|
||||
challenge: bytes
|
||||
host: str
|
||||
action: t.Literal["login", "changePassword"]
|
||||
|
||||
|
||||
class UserOut(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
pubkey: bytes
|
||||
encryptedContent: bytes
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["UserOut"], obj: UserType) -> "UserOut":
|
||||
return cls(
|
||||
username=obj.username,
|
||||
email=obj.email,
|
||||
pubkey=bytes(obj.userinfo.pubkey),
|
||||
encryptedContent=bytes(obj.userinfo.encryptedContent),
|
||||
)
|
||||
|
||||
|
||||
class LoginOut(BaseModel):
|
||||
token: str
|
||||
user: UserOut
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["LoginOut"], obj: UserType) -> "LoginOut":
|
||||
token = AuthToken.objects.create(user=obj).key
|
||||
user = UserOut.from_orm(obj)
|
||||
return cls(token=token, user=user)
|
||||
|
||||
|
||||
class Authentication(BaseModel):
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
|
||||
response: bytes
|
||||
signature: bytes
|
||||
|
||||
|
||||
class Login(Authentication):
|
||||
@cached_property
|
||||
def response_data(self) -> LoginResponse:
|
||||
return LoginResponse(**msgpack_decode(self.response))
|
||||
|
||||
|
||||
class ChangePasswordResponse(LoginResponse):
|
||||
loginPubkey: bytes
|
||||
encryptedContent: bytes
|
||||
|
||||
|
||||
class ChangePassword(Authentication):
|
||||
@cached_property
|
||||
def response_data(self) -> ChangePasswordResponse:
|
||||
return ChangePasswordResponse(**msgpack_decode(self.response))
|
||||
|
||||
|
||||
class UserSignup(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
|
||||
|
||||
class SignupIn(BaseModel):
|
||||
user: UserSignup
|
||||
salt: bytes
|
||||
loginPubkey: bytes
|
||||
pubkey: bytes
|
||||
encryptedContent: bytes
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def __get_login_user(username: str) -> UserType:
|
||||
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
||||
try:
|
||||
user = User.objects.get(**kwargs)
|
||||
if not hasattr(user, "userinfo"):
|
||||
raise AuthenticationFailed(code="user_not_init", detail="User not properly init")
|
||||
return user
|
||||
except User.DoesNotExist:
|
||||
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
||||
|
||||
|
||||
async def get_login_user(challenge: LoginChallengeIn) -> UserType:
|
||||
user = await __get_login_user(challenge.username)
|
||||
return user
|
||||
|
||||
|
||||
def get_encryption_key(salt):
|
||||
key = nacl.hash.blake2b(settings.SECRET_KEY.encode(), encoder=nacl.encoding.RawEncoder)
|
||||
return nacl.hash.blake2b(
|
||||
b"",
|
||||
key=key,
|
||||
salt=salt[: nacl.hash.BLAKE2B_SALTBYTES],
|
||||
person=b"etebase-auth",
|
||||
encoder=nacl.encoding.RawEncoder,
|
||||
)
|
||||
|
||||
|
||||
def save_changed_password(data: ChangePassword, user: UserType):
|
||||
response_data = data.response_data
|
||||
user_info: UserInfo = user.userinfo
|
||||
user_info.loginPubkey = response_data.loginPubkey
|
||||
user_info.encryptedContent = response_data.encryptedContent
|
||||
user_info.save()
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def validate_login_request(
|
||||
validated_data: LoginResponse,
|
||||
challenge_sent_to_user: Authentication,
|
||||
user: UserType,
|
||||
expected_action: str,
|
||||
host_from_request: str,
|
||||
):
|
||||
enc_key = get_encryption_key(bytes(user.userinfo.salt))
|
||||
box = nacl.secret.SecretBox(enc_key)
|
||||
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
||||
now = int(datetime.now().timestamp())
|
||||
if validated_data.action != expected_action:
|
||||
raise HttpError("wrong_action", f'Expected "{expected_action}" but got something else')
|
||||
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
||||
raise HttpError("challenge_expired", "Login challenge has expired")
|
||||
elif challenge_data["userId"] != user.id:
|
||||
raise HttpError("wrong_user", "This challenge is for the wrong user")
|
||||
elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
|
||||
raise HttpError(
|
||||
"wrong_host", f'Found wrong host name. Got: "{validated_data.host}" expected: "{host_from_request}"'
|
||||
)
|
||||
verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
|
||||
try:
|
||||
verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
|
||||
except nacl.exceptions.BadSignatureError:
|
||||
raise HttpError("login_bad_signature", "Wrong password for user.", status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@authentication_router.get("/is_etebase/")
|
||||
async def is_etebase():
|
||||
pass
|
||||
|
||||
|
||||
@authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
|
||||
def login_challenge(user: UserType = Depends(get_login_user)):
|
||||
salt = bytes(user.userinfo.salt)
|
||||
enc_key = get_encryption_key(salt)
|
||||
box = nacl.secret.SecretBox(enc_key)
|
||||
challenge_data = {
|
||||
"timestamp": int(datetime.now().timestamp()),
|
||||
"userId": user.id,
|
||||
}
|
||||
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
|
||||
return LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version)
|
||||
|
||||
|
||||
@authentication_router.post("/login/", response_model=LoginOut)
|
||||
async def login(data: Login, request: Request):
|
||||
user = await get_login_user(LoginChallengeIn(username=data.response_data.username))
|
||||
host = request.headers.get("Host")
|
||||
await validate_login_request(data.response_data, data, user, "login", host)
|
||||
data = await sync_to_async(LoginOut.from_orm)(user)
|
||||
await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user)
|
||||
return data
|
||||
|
||||
|
||||
@authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
|
||||
def logout(auth_data: AuthData = Depends(get_auth_data)):
|
||||
auth_data.token.delete()
|
||||
user_logged_out.send(sender=auth_data.user.__class__, request=None, user=auth_data.user)
|
||||
|
||||
|
||||
@authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
|
||||
async def change_password(data: ChangePassword, request: Request, user: UserType = Depends(get_authenticated_user)):
|
||||
host = request.headers.get("Host")
|
||||
await validate_login_request(data.response_data, data, user, "changePassword", host)
|
||||
await sync_to_async(save_changed_password)(data, user)
|
||||
|
||||
|
||||
@authentication_router.post("/dashboard_url/", responses=permission_responses)
|
||||
def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_user)):
|
||||
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
|
||||
if get_dashboard_url is None:
|
||||
raise HttpError("not_supported", "This server doesn't have a user dashboard.")
|
||||
|
||||
ret = {
|
||||
"url": get_dashboard_url(CallbackContext(request.path_params, user=user)),
|
||||
}
|
||||
return ret
|
||||
|
||||
|
||||
def signup_save(data: SignupIn, request: Request) -> UserType:
|
||||
user_data = data.user
|
||||
with transaction.atomic():
|
||||
try:
|
||||
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(request.path_params))
|
||||
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data.username.lower()})
|
||||
except User.DoesNotExist:
|
||||
# Create the user and save the casing the user chose as the first name
|
||||
try:
|
||||
instance = create_user(
|
||||
CallbackContext(request.path_params),
|
||||
**user_data.dict(),
|
||||
password=None,
|
||||
first_name=user_data.username,
|
||||
)
|
||||
instance.full_clean()
|
||||
except HttpError as e:
|
||||
raise e
|
||||
except django_exceptions.ValidationError as e:
|
||||
transform_validation_error("user", e)
|
||||
except Exception as e:
|
||||
raise HttpError("generic", str(e))
|
||||
|
||||
if hasattr(instance, "userinfo"):
|
||||
raise HttpError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
models.UserInfo.objects.create(**data.dict(exclude={"user"}), owner=instance)
|
||||
return instance
|
||||
|
||||
|
||||
@authentication_router.post("/signup/", response_model=LoginOut, status_code=status.HTTP_201_CREATED)
|
||||
def signup(data: SignupIn, request: Request):
|
||||
user = signup_save(data, request)
|
||||
ret = LoginOut.from_orm(user)
|
||||
user_signed_up.send(sender=user.__class__, request=None, user=user)
|
||||
return ret
|
589
etebase_fastapi/collection.py
Normal file
589
etebase_fastapi/collection.py
Normal file
@ -0,0 +1,589 @@
|
||||
import typing as t
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.core import exceptions as django_exceptions
|
||||
from django.core.files.base import ContentFile
|
||||
from django.db import transaction, IntegrityError
|
||||
from django.db.models import Q, QuerySet
|
||||
from fastapi import APIRouter, Depends, status, Request
|
||||
|
||||
from django_etebase import models
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
from .authentication import get_authenticated_user
|
||||
from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
|
||||
from .msgpack import MsgpackRoute
|
||||
from .stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken
|
||||
from .utils import (
|
||||
get_object_or_404,
|
||||
Context,
|
||||
Prefetch,
|
||||
PrefetchQuery,
|
||||
is_collection_admin,
|
||||
BaseModel,
|
||||
permission_responses,
|
||||
PERMISSIONS_READ,
|
||||
PERMISSIONS_READWRITE,
|
||||
)
|
||||
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
|
||||
from .sendfile import sendfile
|
||||
|
||||
User = get_typed_user_model
|
||||
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||
|
||||
|
||||
class ListMulti(BaseModel):
|
||||
collectionTypes: t.List[bytes]
|
||||
|
||||
|
||||
ChunkType = t.Tuple[str, t.Optional[bytes]]
|
||||
|
||||
|
||||
class CollectionItemRevisionInOut(BaseModel):
|
||||
uid: str
|
||||
meta: bytes
|
||||
deleted: bool
|
||||
chunks: t.List[ChunkType]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def from_orm_context(
|
||||
cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context
|
||||
) -> "CollectionItemRevisionInOut":
|
||||
chunks: t.List[ChunkType] = []
|
||||
for chunk_relation in obj.chunks_relation.all():
|
||||
chunk_obj = chunk_relation.chunk
|
||||
if context.prefetch == "auto":
|
||||
with open(chunk_obj.chunkFile.path, "rb") as f:
|
||||
chunks.append((chunk_obj.uid, f.read()))
|
||||
else:
|
||||
chunks.append((chunk_obj.uid, None))
|
||||
return cls(uid=obj.uid, meta=bytes(obj.meta), deleted=obj.deleted, chunks=chunks)
|
||||
|
||||
|
||||
class CollectionItemCommon(BaseModel):
|
||||
uid: str
|
||||
version: int
|
||||
encryptionKey: t.Optional[bytes]
|
||||
content: CollectionItemRevisionInOut
|
||||
|
||||
|
||||
class CollectionItemOut(CollectionItemCommon):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def from_orm_context(
|
||||
cls: t.Type["CollectionItemOut"], obj: models.CollectionItem, context: Context
|
||||
) -> "CollectionItemOut":
|
||||
return cls(
|
||||
uid=obj.uid,
|
||||
version=obj.version,
|
||||
encryptionKey=obj.encryptionKey,
|
||||
content=CollectionItemRevisionInOut.from_orm_context(obj.content, context),
|
||||
)
|
||||
|
||||
|
||||
class CollectionItemIn(CollectionItemCommon):
|
||||
etag: t.Optional[str]
|
||||
|
||||
|
||||
class CollectionCommon(BaseModel):
|
||||
# FIXME: remove optional once we finish collection-type-migration
|
||||
collectionType: t.Optional[bytes]
|
||||
collectionKey: bytes
|
||||
|
||||
|
||||
class CollectionOut(CollectionCommon):
|
||||
accessLevel: models.AccessLevels
|
||||
stoken: str
|
||||
item: CollectionItemOut
|
||||
|
||||
@classmethod
|
||||
def from_orm_context(cls: t.Type["CollectionOut"], obj: models.Collection, context: Context) -> "CollectionOut":
|
||||
member: models.CollectionMember = obj.members.get(user=context.user)
|
||||
collection_type = member.collectionType
|
||||
ret = cls(
|
||||
collectionType=collection_type and bytes(collection_type.uid),
|
||||
collectionKey=bytes(member.encryptionKey),
|
||||
accessLevel=member.accessLevel,
|
||||
stoken=obj.stoken,
|
||||
item=CollectionItemOut.from_orm_context(obj.main_item, context),
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
class CollectionIn(CollectionCommon):
|
||||
item: CollectionItemIn
|
||||
|
||||
|
||||
class RemovedMembershipOut(BaseModel):
|
||||
uid: str
|
||||
|
||||
|
||||
class CollectionListResponse(BaseModel):
|
||||
data: t.List[CollectionOut]
|
||||
stoken: t.Optional[str]
|
||||
done: bool
|
||||
|
||||
removedMemberships: t.Optional[t.List[RemovedMembershipOut]]
|
||||
|
||||
|
||||
class CollectionItemListResponse(BaseModel):
|
||||
data: t.List[CollectionItemOut]
|
||||
stoken: t.Optional[str]
|
||||
done: bool
|
||||
|
||||
|
||||
class CollectionItemRevisionListResponse(BaseModel):
|
||||
data: t.List[CollectionItemRevisionInOut]
|
||||
iterator: t.Optional[str]
|
||||
done: bool
|
||||
|
||||
|
||||
class CollectionItemBulkGetIn(BaseModel):
|
||||
uid: str
|
||||
etag: t.Optional[str]
|
||||
|
||||
|
||||
class ItemDepIn(BaseModel):
|
||||
uid: str
|
||||
etag: str
|
||||
|
||||
def validate_db(self):
|
||||
item = models.CollectionItem.objects.get(uid=self.uid)
|
||||
etag = self.etag
|
||||
if item.etag != etag:
|
||||
raise ValidationError(
|
||||
"wrong_etag",
|
||||
"Wrong etag. Expected {} got {}".format(item.etag, etag),
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
field=self.uid,
|
||||
)
|
||||
|
||||
|
||||
class ItemBatchIn(BaseModel):
|
||||
items: t.List[CollectionItemIn]
|
||||
deps: t.Optional[t.List[ItemDepIn]]
|
||||
|
||||
def validate_db(self):
|
||||
if self.deps is not None:
|
||||
errors: t.List[HttpError] = []
|
||||
for dep in self.deps:
|
||||
try:
|
||||
dep.validate_db()
|
||||
except ValidationError as e:
|
||||
errors.append(e)
|
||||
if len(errors) > 0:
|
||||
raise ValidationError(
|
||||
code="dep_failed",
|
||||
detail="Dependencies failed to validate",
|
||||
errors=errors,
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def collection_list_common(
|
||||
queryset: QuerySet,
|
||||
user: UserType,
|
||||
stoken: t.Optional[str],
|
||||
limit: int,
|
||||
prefetch: Prefetch,
|
||||
) -> CollectionListResponse:
|
||||
result, new_stoken_obj, done = filter_by_stoken_and_limit(
|
||||
stoken, limit, queryset, models.Collection.stoken_annotation
|
||||
)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
context = Context(user, prefetch)
|
||||
data: t.List[CollectionOut] = [CollectionOut.from_orm_context(item, context) for item in result]
|
||||
|
||||
ret = CollectionListResponse(data=data, stoken=new_stoken, done=done)
|
||||
|
||||
stoken_obj = get_stoken_obj(stoken)
|
||||
if stoken_obj is not None:
|
||||
# FIXME: honour limit? (the limit should be combined for data and this because of stoken)
|
||||
remed_qs = models.CollectionMemberRemoved.objects.filter(user=user, stoken__id__gt=stoken_obj.id)
|
||||
if not done and new_stoken_obj is not None:
|
||||
# We only filter by the new_stoken if we are not done. This is because if we are done, the new stoken
|
||||
# can point to the most recent collection change rather than most recent removed membership.
|
||||
remed_qs = remed_qs.filter(stoken__id__lte=new_stoken_obj.id)
|
||||
|
||||
remed = remed_qs.values_list("collection__uid", flat=True)
|
||||
if len(remed) > 0:
|
||||
ret.removedMemberships = [RemovedMembershipOut(uid=x) for x in remed]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
# permissions
|
||||
|
||||
|
||||
def verify_collection_admin(
|
||||
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
|
||||
):
|
||||
if not is_collection_admin(collection, user):
|
||||
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
|
||||
|
||||
|
||||
def has_write_access(
|
||||
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
|
||||
):
|
||||
member = collection.members.get(user=user)
|
||||
if member.accessLevel == models.AccessLevels.READ_ONLY:
|
||||
raise PermissionDenied("no_write_access", "You need write access to write to this collection")
|
||||
|
||||
|
||||
# paths
|
||||
|
||||
|
||||
@collection_router.post(
|
||||
"/list_multi/",
|
||||
response_model=CollectionListResponse,
|
||||
response_model_exclude_unset=True,
|
||||
dependencies=PERMISSIONS_READ,
|
||||
)
|
||||
async def list_multi(
|
||||
data: ListMulti,
|
||||
stoken: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
queryset: QuerySet = Depends(get_collection_queryset),
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
):
|
||||
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
|
||||
queryset = queryset.filter(
|
||||
Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True)
|
||||
)
|
||||
|
||||
return await collection_list_common(queryset, user, stoken, limit, prefetch)
|
||||
|
||||
|
||||
@collection_router.get("/", response_model=CollectionListResponse, dependencies=PERMISSIONS_READ)
|
||||
async def collection_list(
|
||||
stoken: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
queryset: QuerySet = Depends(get_collection_queryset),
|
||||
):
|
||||
return await collection_list_common(queryset, user, stoken, limit, prefetch)
|
||||
|
||||
|
||||
def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionInOut):
|
||||
chunks_objs = []
|
||||
|
||||
revision = models.CollectionItemRevision(**revision_data.dict(exclude={"chunks"}), item=item)
|
||||
revision.validate_unique() # Verify there aren't any validation issues
|
||||
|
||||
for chunk in revision_data.chunks:
|
||||
uid = chunk[0]
|
||||
chunk_obj = models.CollectionItemChunk.objects.filter(uid=uid).first()
|
||||
content = chunk[1] if len(chunk) > 1 else None
|
||||
# If the chunk already exists we assume it's fine. Otherwise, we upload it.
|
||||
if chunk_obj is None:
|
||||
if content is not None:
|
||||
chunk_obj = models.CollectionItemChunk(uid=uid, collection=item.collection)
|
||||
chunk_obj.chunkFile.save("IGNORED", ContentFile(content))
|
||||
chunk_obj.save()
|
||||
else:
|
||||
raise ValidationError("chunk_no_content", "Tried to create a new chunk without content")
|
||||
|
||||
chunks_objs.append(chunk_obj)
|
||||
|
||||
stoken = models.Stoken.objects.create()
|
||||
revision.stoken = stoken
|
||||
revision.save()
|
||||
|
||||
for chunk in chunks_objs:
|
||||
models.RevisionChunkRelation.objects.create(chunk=chunk, revision=revision)
|
||||
return revision
|
||||
|
||||
|
||||
def _create(data: CollectionIn, user: UserType):
|
||||
with transaction.atomic():
|
||||
if data.item.etag is not None:
|
||||
raise ValidationError("bad_etag", "etag is not null")
|
||||
instance = models.Collection(uid=data.item.uid, owner=user)
|
||||
try:
|
||||
instance.validate_unique()
|
||||
except django_exceptions.ValidationError:
|
||||
raise ValidationError(
|
||||
"unique_uid", "Collection with this uid already exists", status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
instance.save()
|
||||
|
||||
main_item = models.CollectionItem.objects.create(
|
||||
uid=data.item.uid, version=data.item.version, collection=instance
|
||||
)
|
||||
|
||||
instance.main_item = main_item
|
||||
instance.save()
|
||||
|
||||
# TODO
|
||||
process_revisions_for_item(main_item, data.item.content)
|
||||
|
||||
collection_type_obj, _ = models.CollectionType.objects.get_or_create(uid=data.collectionType, owner=user)
|
||||
|
||||
models.CollectionMember(
|
||||
collection=instance,
|
||||
stoken=models.Stoken.objects.create(),
|
||||
user=user,
|
||||
accessLevel=models.AccessLevels.ADMIN,
|
||||
encryptionKey=data.collectionKey,
|
||||
collectionType=collection_type_obj,
|
||||
).save()
|
||||
|
||||
|
||||
@collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
|
||||
async def create(data: CollectionIn, user: UserType = Depends(get_authenticated_user)):
|
||||
await sync_to_async(_create)(data, user)
|
||||
|
||||
|
||||
@collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ)
|
||||
def collection_get(
|
||||
obj: models.Collection = Depends(get_collection),
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
):
|
||||
return CollectionOut.from_orm_context(obj, Context(user, prefetch))
|
||||
|
||||
|
||||
def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool):
|
||||
"""Function that's called when this serializer creates an item"""
|
||||
etag = item_model.etag
|
||||
revision_data = item_model.content
|
||||
uid = item_model.uid
|
||||
|
||||
Model = models.CollectionItem
|
||||
|
||||
with transaction.atomic():
|
||||
instance, created = Model.objects.get_or_create(
|
||||
uid=uid, collection=collection, defaults=item_model.dict(exclude={"uid", "etag", "content"})
|
||||
)
|
||||
cur_etag = instance.etag if not created else None
|
||||
|
||||
# If we are trying to update an up to date item, abort early and consider it a success
|
||||
if cur_etag == revision_data.uid:
|
||||
return instance
|
||||
|
||||
if validate_etag and cur_etag != etag:
|
||||
raise ValidationError(
|
||||
"wrong_etag",
|
||||
"Wrong etag. Expected {} got {}".format(cur_etag, etag),
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
field=uid,
|
||||
)
|
||||
|
||||
if not created:
|
||||
# We don't have to use select_for_update here because the unique constraint on current guards against
|
||||
# the race condition. But it's a good idea because it'll lock and wait rather than fail.
|
||||
current_revision = instance.revisions.filter(current=True).select_for_update().first()
|
||||
current_revision.current = None
|
||||
current_revision.save()
|
||||
|
||||
try:
|
||||
process_revisions_for_item(instance, revision_data)
|
||||
except django_exceptions.ValidationError as e:
|
||||
transform_validation_error("content", e)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
@item_router.get("/item/{item_uid}/", response_model=CollectionItemOut, dependencies=PERMISSIONS_READ)
|
||||
def item_get(
|
||||
item_uid: str,
|
||||
queryset: QuerySet = Depends(get_item_queryset),
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
):
|
||||
obj = queryset.get(uid=item_uid)
|
||||
return CollectionItemOut.from_orm_context(obj, Context(user, prefetch))
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def item_list_common(
|
||||
queryset: QuerySet,
|
||||
user: UserType,
|
||||
stoken: t.Optional[str],
|
||||
limit: int,
|
||||
prefetch: Prefetch,
|
||||
) -> CollectionItemListResponse:
|
||||
result, new_stoken_obj, done = filter_by_stoken_and_limit(
|
||||
stoken, limit, queryset, models.CollectionItem.stoken_annotation
|
||||
)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
context = Context(user, prefetch)
|
||||
data: t.List[CollectionItemOut] = [CollectionItemOut.from_orm_context(item, context) for item in result]
|
||||
return CollectionItemListResponse(data=data, stoken=new_stoken, done=done)
|
||||
|
||||
|
||||
@item_router.get("/item/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ)
|
||||
async def item_list(
|
||||
queryset: QuerySet = Depends(get_item_queryset),
|
||||
stoken: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
withCollection: bool = False,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
):
|
||||
if not withCollection:
|
||||
queryset = queryset.filter(parent__isnull=True)
|
||||
|
||||
response = await item_list_common(queryset, user, stoken, limit, prefetch)
|
||||
return response
|
||||
|
||||
|
||||
def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
|
||||
queryset = get_collection_queryset(user)
|
||||
with transaction.atomic(): # We need this for locking the collection object
|
||||
collection_object = queryset.select_for_update().get(uid=uid)
|
||||
|
||||
if stoken is not None and stoken != collection_object.stoken:
|
||||
raise HttpError("stale_stoken", "Stoken is too old", status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
data.validate_db()
|
||||
|
||||
errors: t.List[HttpError] = []
|
||||
for item in data.items:
|
||||
try:
|
||||
item_create(item, collection_object, validate_etag)
|
||||
except ValidationError as e:
|
||||
errors.append(e)
|
||||
|
||||
if len(errors) > 0:
|
||||
raise ValidationError(
|
||||
code="item_failed",
|
||||
detail="Items failed to validate",
|
||||
errors=errors,
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
|
||||
@item_router.get(
|
||||
"/item/{item_uid}/revision/", response_model=CollectionItemRevisionListResponse, dependencies=PERMISSIONS_READ
|
||||
)
|
||||
def item_revisions(
|
||||
item_uid: str,
|
||||
limit: int = 50,
|
||||
iterator: t.Optional[str] = None,
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
items: QuerySet = Depends(get_item_queryset),
|
||||
):
|
||||
item = get_object_or_404(items, uid=item_uid)
|
||||
|
||||
queryset = item.revisions.order_by("-id")
|
||||
|
||||
if iterator is not None:
|
||||
iterator_obj = get_object_or_404(queryset, uid=iterator)
|
||||
queryset = queryset.filter(id__lt=iterator_obj.id)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
context = Context(user, prefetch)
|
||||
ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result]
|
||||
iterator = ret_data[-1].uid if len(result) > 0 else None
|
||||
|
||||
return CollectionItemRevisionListResponse(
|
||||
data=ret_data,
|
||||
iterator=iterator,
|
||||
done=done,
|
||||
)
|
||||
|
||||
|
||||
@item_router.post("/item/fetch_updates/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ)
|
||||
def fetch_updates(
|
||||
data: t.List[CollectionItemBulkGetIn],
|
||||
stoken: t.Optional[str] = None,
|
||||
prefetch: Prefetch = PrefetchQuery,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
queryset: QuerySet = Depends(get_item_queryset),
|
||||
):
|
||||
# FIXME: make configurable?
|
||||
item_limit = 200
|
||||
|
||||
if len(data) > item_limit:
|
||||
raise HttpError("too_many_items", "Request has too many items.", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
queryset, stoken_rev = filter_by_stoken(stoken, queryset, models.CollectionItem.stoken_annotation)
|
||||
|
||||
uids, etags = zip(*[(item.uid, item.etag) for item in data])
|
||||
revs = models.CollectionItemRevision.objects.filter(uid__in=etags, current=True)
|
||||
queryset = queryset.filter(uid__in=uids).exclude(revisions__in=revs)
|
||||
|
||||
new_stoken_obj = get_queryset_stoken(queryset)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
stoken = stoken_rev and getattr(stoken_rev, "uid", None)
|
||||
new_stoken = new_stoken or stoken
|
||||
|
||||
context = Context(user, prefetch)
|
||||
return CollectionItemListResponse(
|
||||
data=[CollectionItemOut.from_orm_context(item, context) for item in queryset],
|
||||
stoken=new_stoken,
|
||||
done=True, # we always return all the items, so it's always done
|
||||
)
|
||||
|
||||
|
||||
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
||||
def item_transaction(
|
||||
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
|
||||
):
|
||||
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True)
|
||||
|
||||
|
||||
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
||||
def item_batch(
|
||||
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
|
||||
):
|
||||
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False)
|
||||
|
||||
|
||||
# Chunks
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def chunk_save(chunk_uid: str, collection: models.Collection, content_file: ContentFile):
|
||||
chunk_obj = models.CollectionItemChunk(uid=chunk_uid, collection=collection)
|
||||
chunk_obj.chunkFile.save("IGNORED", content_file)
|
||||
chunk_obj.save()
|
||||
return chunk_obj
|
||||
|
||||
|
||||
@item_router.put(
|
||||
"/item/{item_uid}/chunk/{chunk_uid}/",
|
||||
dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE],
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def chunk_update(
|
||||
request: Request,
|
||||
chunk_uid: str,
|
||||
collection: models.Collection = Depends(get_collection),
|
||||
):
|
||||
# IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid)
|
||||
content_file = ContentFile(await request.body())
|
||||
try:
|
||||
await chunk_save(chunk_uid, collection, content_file)
|
||||
except IntegrityError:
|
||||
raise HttpError("chunk_exists", "Chunk already exists.", status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
|
||||
@item_router.get(
|
||||
"/item/{item_uid}/chunk/{chunk_uid}/download/",
|
||||
dependencies=PERMISSIONS_READ,
|
||||
)
|
||||
def chunk_download(
|
||||
chunk_uid: str,
|
||||
collection: models.Collection = Depends(get_collection),
|
||||
):
|
||||
chunk = get_object_or_404(collection.chunks, uid=chunk_uid)
|
||||
|
||||
filename = chunk.chunkFile.path
|
||||
return sendfile(filename)
|
82
etebase_fastapi/dependencies.py
Normal file
82
etebase_fastapi/dependencies.py
Normal file
@ -0,0 +1,82 @@
|
||||
import dataclasses
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from django.utils import timezone
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from django_etebase import models
|
||||
from django_etebase.token_auth.models import AuthToken, get_default_expiry
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
from .exceptions import AuthenticationFailed
|
||||
from .utils import get_object_or_404
|
||||
|
||||
|
||||
User = get_typed_user_model()
|
||||
token_scheme = APIKeyHeader(name="Authorization")
|
||||
AUTO_REFRESH = True
|
||||
MIN_REFRESH_INTERVAL = 60
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AuthData:
|
||||
user: UserType
|
||||
token: AuthToken
|
||||
|
||||
|
||||
def __renew_token(auth_token: AuthToken):
|
||||
current_expiry = auth_token.expiry
|
||||
new_expiry = get_default_expiry()
|
||||
# Throttle refreshing of token to avoid db writes
|
||||
delta = (new_expiry - current_expiry).total_seconds()
|
||||
if delta > MIN_REFRESH_INTERVAL:
|
||||
auth_token.expiry = new_expiry
|
||||
auth_token.save(update_fields=("expiry",))
|
||||
|
||||
|
||||
def __get_authenticated_user(api_token: str):
|
||||
api_token = api_token.split()[1]
|
||||
try:
|
||||
token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token)
|
||||
except AuthToken.DoesNotExist:
|
||||
raise AuthenticationFailed(detail="Invalid token.")
|
||||
if not token.user.is_active:
|
||||
raise AuthenticationFailed(detail="User inactive or deleted.")
|
||||
|
||||
if token.expiry is not None:
|
||||
if token.expiry < timezone.now():
|
||||
token.delete()
|
||||
raise AuthenticationFailed(detail="Invalid token.")
|
||||
|
||||
if AUTO_REFRESH:
|
||||
__renew_token(token)
|
||||
|
||||
return token.user, token
|
||||
|
||||
|
||||
def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
|
||||
user, token = __get_authenticated_user(api_token)
|
||||
return AuthData(user, token)
|
||||
|
||||
|
||||
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType:
|
||||
user, _ = __get_authenticated_user(api_token)
|
||||
return user
|
||||
|
||||
|
||||
def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet:
|
||||
default_queryset: QuerySet = models.Collection.objects.all()
|
||||
return default_queryset.filter(members__user=user)
|
||||
|
||||
|
||||
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
|
||||
return get_object_or_404(queryset, uid=collection_uid)
|
||||
|
||||
|
||||
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
|
||||
default_item_queryset: QuerySet = models.CollectionItem.objects.all()
|
||||
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
|
||||
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
|
||||
|
||||
return queryset
|
118
etebase_fastapi/exceptions.py
Normal file
118
etebase_fastapi/exceptions.py
Normal file
@ -0,0 +1,118 @@
|
||||
from fastapi import status
|
||||
import typing as t
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HttpErrorField(BaseModel):
|
||||
field: str
|
||||
code: str
|
||||
detail: str
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class HttpErrorOut(BaseModel):
|
||||
code: str
|
||||
detail: str
|
||||
errors: t.Optional[t.List[HttpErrorField]]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class CustomHttpException(Exception):
|
||||
def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
|
||||
self.status_code = status_code
|
||||
self.code = code
|
||||
self.detail = detail
|
||||
|
||||
@property
|
||||
def as_dict(self) -> dict:
|
||||
return {"code": self.code, "detail": self.detail}
|
||||
|
||||
|
||||
class AuthenticationFailed(CustomHttpException):
|
||||
def __init__(
|
||||
self,
|
||||
code="authentication_failed",
|
||||
detail: str = "Incorrect authentication credentials.",
|
||||
status_code: int = status.HTTP_401_UNAUTHORIZED,
|
||||
):
|
||||
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||
|
||||
|
||||
class NotAuthenticated(CustomHttpException):
|
||||
def __init__(
|
||||
self,
|
||||
code="not_authenticated",
|
||||
detail: str = "Authentication credentials were not provided.",
|
||||
status_code: int = status.HTTP_401_UNAUTHORIZED,
|
||||
):
|
||||
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||
|
||||
|
||||
class PermissionDenied(CustomHttpException):
|
||||
def __init__(
|
||||
self,
|
||||
code="permission_denied",
|
||||
detail: str = "You do not have permission to perform this action.",
|
||||
status_code: int = status.HTTP_403_FORBIDDEN,
|
||||
):
|
||||
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||
|
||||
|
||||
class HttpError(CustomHttpException):
|
||||
def __init__(
|
||||
self,
|
||||
code: str,
|
||||
detail: str,
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
errors: t.Optional[t.List["HttpError"]] = None,
|
||||
):
|
||||
self.errors = errors
|
||||
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||
|
||||
@property
|
||||
def as_dict(self) -> dict:
|
||||
return HttpErrorOut(code=self.code, errors=self.errors, detail=self.detail).dict()
|
||||
|
||||
|
||||
class ValidationError(HttpError):
|
||||
def __init__(
|
||||
self,
|
||||
code: str,
|
||||
detail: str,
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
errors: t.Optional[t.List["HttpError"]] = None,
|
||||
field: t.Optional[str] = None,
|
||||
):
|
||||
self.field = field
|
||||
super().__init__(code=code, detail=detail, errors=errors, status_code=status_code)
|
||||
|
||||
|
||||
def flatten_errors(field_name, errors) -> t.List[HttpError]:
|
||||
ret = []
|
||||
if isinstance(errors, dict):
|
||||
for error_key in errors:
|
||||
error = errors[error_key]
|
||||
ret.extend(flatten_errors("{}.{}".format(field_name, error_key), error))
|
||||
else:
|
||||
for error in errors:
|
||||
if error.messages:
|
||||
message = error.messages[0]
|
||||
else:
|
||||
message = str(error)
|
||||
ret.append(dict(code=error.code, detail=message, field=field_name))
|
||||
return ret
|
||||
|
||||
|
||||
def transform_validation_error(prefix, err):
|
||||
if hasattr(err, "error_dict"):
|
||||
errors = flatten_errors(prefix, err.error_dict)
|
||||
elif not hasattr(err, "message"):
|
||||
errors = flatten_errors(prefix, err.error_list)
|
||||
else:
|
||||
raise HttpError(err.code, err.message)
|
||||
raise HttpError(code="field_errors", detail="Field validations failed.", errors=errors)
|
240
etebase_fastapi/invitation.py
Normal file
240
etebase_fastapi/invitation.py
Normal file
@ -0,0 +1,240 @@
|
||||
import typing as t
|
||||
|
||||
from django.db import transaction, IntegrityError
|
||||
from django.db.models import QuerySet
|
||||
from fastapi import APIRouter, Depends, status, Request
|
||||
|
||||
from django_etebase import models
|
||||
from django_etebase.utils import get_user_queryset, CallbackContext
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
from .authentication import get_authenticated_user
|
||||
from .exceptions import HttpError, PermissionDenied
|
||||
from .msgpack import MsgpackRoute
|
||||
from .utils import (
|
||||
get_object_or_404,
|
||||
Context,
|
||||
is_collection_admin,
|
||||
BaseModel,
|
||||
permission_responses,
|
||||
PERMISSIONS_READ,
|
||||
PERMISSIONS_READWRITE,
|
||||
)
|
||||
|
||||
User = get_typed_user_model()
|
||||
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||
default_queryset: QuerySet = models.CollectionInvitation.objects.all()
|
||||
|
||||
|
||||
class UserInfoOut(BaseModel):
|
||||
pubkey: bytes
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["UserInfoOut"], obj: models.UserInfo) -> "UserInfoOut":
|
||||
return cls(pubkey=bytes(obj.pubkey))
|
||||
|
||||
|
||||
class CollectionInvitationAcceptIn(BaseModel):
|
||||
collectionType: bytes
|
||||
encryptionKey: bytes
|
||||
|
||||
|
||||
class CollectionInvitationCommon(BaseModel):
|
||||
uid: str
|
||||
version: int
|
||||
accessLevel: models.AccessLevels
|
||||
username: str
|
||||
collection: str
|
||||
signedEncryptionKey: bytes
|
||||
|
||||
|
||||
class CollectionInvitationIn(CollectionInvitationCommon):
|
||||
def validate_db(self, context: Context):
|
||||
user = context.user
|
||||
if user is not None and (user.username == self.username.lower()):
|
||||
raise HttpError("no_self_invite", "Inviting yourself is not allowed")
|
||||
|
||||
|
||||
class CollectionInvitationOut(CollectionInvitationCommon):
|
||||
fromUsername: str
|
||||
fromPubkey: bytes
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["CollectionInvitationOut"], obj: models.CollectionInvitation) -> "CollectionInvitationOut":
|
||||
return cls(
|
||||
uid=obj.uid,
|
||||
version=obj.version,
|
||||
accessLevel=obj.accessLevel,
|
||||
username=obj.user.username,
|
||||
collection=obj.collection.uid,
|
||||
fromUsername=obj.fromMember.user.username,
|
||||
fromPubkey=bytes(obj.fromMember.user.userinfo.pubkey),
|
||||
signedEncryptionKey=bytes(obj.signedEncryptionKey),
|
||||
)
|
||||
|
||||
|
||||
class InvitationListResponse(BaseModel):
|
||||
data: t.List[CollectionInvitationOut]
|
||||
iterator: t.Optional[str]
|
||||
done: bool
|
||||
|
||||
|
||||
def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)):
|
||||
return default_queryset.filter(user=user)
|
||||
|
||||
|
||||
def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
|
||||
return default_queryset.filter(fromMember__user=user)
|
||||
|
||||
|
||||
def list_common(
|
||||
queryset: QuerySet,
|
||||
iterator: t.Optional[str],
|
||||
limit: int,
|
||||
) -> InvitationListResponse:
|
||||
queryset = queryset.order_by("id")
|
||||
|
||||
if iterator is not None:
|
||||
iterator_obj = get_object_or_404(queryset, uid=iterator)
|
||||
queryset = queryset.filter(id__gt=iterator_obj.id)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
ret_data = result
|
||||
iterator = ret_data[-1].uid if len(result) > 0 else None
|
||||
|
||||
return InvitationListResponse(
|
||||
data=ret_data,
|
||||
iterator=iterator,
|
||||
done=done,
|
||||
)
|
||||
|
||||
|
||||
@invitation_incoming_router.get("/", response_model=InvitationListResponse, dependencies=PERMISSIONS_READ)
|
||||
def incoming_list(
|
||||
iterator: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
queryset: QuerySet = Depends(get_incoming_queryset),
|
||||
):
|
||||
return list_common(queryset, iterator, limit)
|
||||
|
||||
|
||||
@invitation_incoming_router.get(
|
||||
"/{invitation_uid}/", response_model=CollectionInvitationOut, dependencies=PERMISSIONS_READ
|
||||
)
|
||||
def incoming_get(
|
||||
invitation_uid: str,
|
||||
queryset: QuerySet = Depends(get_incoming_queryset),
|
||||
):
|
||||
obj = get_object_or_404(queryset, uid=invitation_uid)
|
||||
return CollectionInvitationOut.from_orm(obj)
|
||||
|
||||
|
||||
@invitation_incoming_router.delete(
|
||||
"/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READWRITE
|
||||
)
|
||||
def incoming_delete(
|
||||
invitation_uid: str,
|
||||
queryset: QuerySet = Depends(get_incoming_queryset),
|
||||
):
|
||||
obj = get_object_or_404(queryset, uid=invitation_uid)
|
||||
obj.delete()
|
||||
|
||||
|
||||
@invitation_incoming_router.post(
|
||||
"/{invitation_uid}/accept/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE
|
||||
)
|
||||
def incoming_accept(
|
||||
invitation_uid: str,
|
||||
data: CollectionInvitationAcceptIn,
|
||||
queryset: QuerySet = Depends(get_incoming_queryset),
|
||||
):
|
||||
invitation = get_object_or_404(queryset, uid=invitation_uid)
|
||||
|
||||
with transaction.atomic():
|
||||
user = invitation.user
|
||||
collection_type_obj, _ = models.CollectionType.objects.get_or_create(uid=data.collectionType, owner=user)
|
||||
|
||||
models.CollectionMember.objects.create(
|
||||
collection=invitation.collection,
|
||||
stoken=models.Stoken.objects.create(),
|
||||
user=user,
|
||||
accessLevel=invitation.accessLevel,
|
||||
encryptionKey=data.encryptionKey,
|
||||
collectionType=collection_type_obj,
|
||||
)
|
||||
|
||||
models.CollectionMemberRemoved.objects.filter(user=invitation.user, collection=invitation.collection).delete()
|
||||
|
||||
invitation.delete()
|
||||
|
||||
|
||||
@invitation_outgoing_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
|
||||
def outgoing_create(
|
||||
data: CollectionInvitationIn,
|
||||
request: Request,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
):
|
||||
collection = get_object_or_404(models.Collection.objects, uid=data.collection)
|
||||
to_user = get_object_or_404(
|
||||
get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), username=data.username
|
||||
)
|
||||
|
||||
context = Context(user, None)
|
||||
data.validate_db(context)
|
||||
|
||||
if not is_collection_admin(collection, user):
|
||||
raise PermissionDenied("admin_access_required", "User is not an admin of this collection")
|
||||
|
||||
member = collection.members.get(user=user)
|
||||
|
||||
with transaction.atomic():
|
||||
try:
|
||||
ret = models.CollectionInvitation.objects.create(
|
||||
**data.dict(exclude={"collection", "username"}), user=to_user, fromMember=member
|
||||
)
|
||||
except IntegrityError:
|
||||
raise HttpError("invitation_exists", "Invitation already exists")
|
||||
|
||||
|
||||
@invitation_outgoing_router.get("/", response_model=InvitationListResponse, dependencies=PERMISSIONS_READ)
|
||||
def outgoing_list(
|
||||
iterator: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
queryset: QuerySet = Depends(get_outgoing_queryset),
|
||||
):
|
||||
return list_common(queryset, iterator, limit)
|
||||
|
||||
|
||||
@invitation_outgoing_router.delete(
|
||||
"/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READWRITE
|
||||
)
|
||||
def outgoing_delete(
|
||||
invitation_uid: str,
|
||||
queryset: QuerySet = Depends(get_outgoing_queryset),
|
||||
):
|
||||
obj = get_object_or_404(queryset, uid=invitation_uid)
|
||||
obj.delete()
|
||||
|
||||
|
||||
@invitation_outgoing_router.get("/fetch_user_profile/", response_model=UserInfoOut, dependencies=PERMISSIONS_READ)
|
||||
def outgoing_fetch_user_profile(
|
||||
username: str,
|
||||
request: Request,
|
||||
user: UserType = Depends(get_authenticated_user),
|
||||
):
|
||||
kwargs = {User.USERNAME_FIELD: username.lower()}
|
||||
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
|
||||
user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user)
|
||||
return UserInfoOut.from_orm(user_info)
|
61
etebase_fastapi/main.py
Normal file
61
etebase_fastapi/main.py
Normal file
@ -0,0 +1,61 @@
|
||||
from django.conf import settings
|
||||
|
||||
# Not at the top of the file because we first need to setup django
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
from .exceptions import CustomHttpException
|
||||
from .authentication import authentication_router
|
||||
from .collection import collection_router, item_router
|
||||
from .member import member_router
|
||||
from .invitation import invitation_incoming_router, invitation_outgoing_router
|
||||
from .msgpack import MsgpackResponse
|
||||
|
||||
|
||||
def create_application(prefix="", middlewares=[]):
|
||||
app = FastAPI(
|
||||
title="Etebase",
|
||||
description="The Etebase server API documentation",
|
||||
externalDocs={
|
||||
"url": "https://docs.etebase.com",
|
||||
"description": "Docs about the API specifications and clients.",
|
||||
}
|
||||
# FIXME: version="2.5.0",
|
||||
)
|
||||
VERSION = "v1"
|
||||
BASE_PATH = f"{prefix}/api/{VERSION}"
|
||||
COLLECTION_UID_MARKER = "{collection_uid}"
|
||||
app.include_router(authentication_router, prefix=f"{BASE_PATH}/authentication", tags=["authentication"])
|
||||
app.include_router(collection_router, prefix=f"{BASE_PATH}/collection", tags=["collection"])
|
||||
app.include_router(item_router, prefix=f"{BASE_PATH}/collection/{COLLECTION_UID_MARKER}", tags=["item"])
|
||||
app.include_router(member_router, prefix=f"{BASE_PATH}/collection/{COLLECTION_UID_MARKER}", tags=["member"])
|
||||
app.include_router(
|
||||
invitation_incoming_router, prefix=f"{BASE_PATH}/invitation/incoming", tags=["incoming invitation"]
|
||||
)
|
||||
app.include_router(
|
||||
invitation_outgoing_router, prefix=f"{BASE_PATH}/invitation/outgoing", tags=["outgoing invitation"]
|
||||
)
|
||||
|
||||
if settings.DEBUG:
|
||||
from etebase_fastapi.test_reset_view import test_reset_view_router
|
||||
|
||||
app.include_router(test_reset_view_router, prefix=f"{BASE_PATH}/test/authentication")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origin_regex="https?://.*",
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.ALLOWED_HOSTS)
|
||||
|
||||
for middleware in middlewares:
|
||||
app.add_middleware(middleware)
|
||||
|
||||
@app.exception_handler(CustomHttpException)
|
||||
async def custom_exception_handler(request: Request, exc: CustomHttpException):
|
||||
return MsgpackResponse(status_code=exc.status_code, content=exc.as_dict)
|
||||
|
||||
return app
|
105
etebase_fastapi/member.py
Normal file
105
etebase_fastapi/member.py
Normal file
@ -0,0 +1,105 @@
|
||||
import typing as t
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from django_etebase import models
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
from .authentication import get_authenticated_user
|
||||
from .msgpack import MsgpackRoute
|
||||
from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
|
||||
from .stoken_handler import filter_by_stoken_and_limit
|
||||
|
||||
from .collection import get_collection, verify_collection_admin
|
||||
|
||||
User = get_typed_user_model()
|
||||
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||
default_queryset: QuerySet = models.CollectionMember.objects.all()
|
||||
|
||||
|
||||
def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
|
||||
return default_queryset.filter(collection=collection)
|
||||
|
||||
|
||||
def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet:
|
||||
return get_object_or_404(queryset, user__username__iexact=username)
|
||||
|
||||
|
||||
class CollectionMemberModifyAccessLevelIn(BaseModel):
|
||||
accessLevel: models.AccessLevels
|
||||
|
||||
|
||||
class CollectionMemberOut(BaseModel):
|
||||
username: str
|
||||
accessLevel: models.AccessLevels
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["CollectionMemberOut"], obj: models.CollectionMember) -> "CollectionMemberOut":
|
||||
return cls(username=obj.user.username, accessLevel=obj.accessLevel)
|
||||
|
||||
|
||||
class MemberListResponse(BaseModel):
|
||||
data: t.List[CollectionMemberOut]
|
||||
iterator: t.Optional[str]
|
||||
done: bool
|
||||
|
||||
|
||||
@member_router.get(
|
||||
"/member/", response_model=MemberListResponse, dependencies=[Depends(verify_collection_admin), *PERMISSIONS_READ]
|
||||
)
|
||||
def member_list(
|
||||
iterator: t.Optional[str] = None,
|
||||
limit: int = 50,
|
||||
queryset: QuerySet = Depends(get_queryset),
|
||||
):
|
||||
queryset = queryset.order_by("id")
|
||||
result, new_stoken_obj, done = filter_by_stoken_and_limit(
|
||||
iterator, limit, queryset, models.CollectionMember.stoken_annotation
|
||||
)
|
||||
new_stoken = new_stoken_obj and new_stoken_obj.uid
|
||||
|
||||
return MemberListResponse(
|
||||
data=[CollectionMemberOut.from_orm(item) for item in result],
|
||||
iterator=new_stoken,
|
||||
done=done,
|
||||
)
|
||||
|
||||
|
||||
@member_router.delete(
|
||||
"/member/{username}/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Depends(verify_collection_admin), *PERMISSIONS_READWRITE],
|
||||
)
|
||||
def member_delete(
|
||||
obj: models.CollectionMember = Depends(get_member),
|
||||
):
|
||||
obj.revoke()
|
||||
|
||||
|
||||
@member_router.patch(
|
||||
"/member/{username}/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Depends(verify_collection_admin), *PERMISSIONS_READWRITE],
|
||||
)
|
||||
def member_patch(
|
||||
data: CollectionMemberModifyAccessLevelIn,
|
||||
instance: models.CollectionMember = Depends(get_member),
|
||||
):
|
||||
with transaction.atomic():
|
||||
# We only allow updating accessLevel
|
||||
if instance.accessLevel != data.accessLevel:
|
||||
instance.stoken = models.Stoken.objects.create()
|
||||
instance.accessLevel = data.accessLevel
|
||||
instance.save()
|
||||
|
||||
|
||||
@member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ)
|
||||
def member_leave(
|
||||
user: UserType = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)
|
||||
):
|
||||
obj = get_object_or_404(collection.members, user=user)
|
||||
obj.revoke()
|
71
etebase_fastapi/msgpack.py
Normal file
71
etebase_fastapi/msgpack.py
Normal file
@ -0,0 +1,71 @@
|
||||
import typing as t
|
||||
import msgpack
|
||||
from fastapi.routing import APIRoute, get_request_handler
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class MsgpackRequest(Request):
|
||||
media_type = "application/msgpack"
|
||||
|
||||
async def json(self) -> bytes:
|
||||
if not hasattr(self, "_json"):
|
||||
body = await super().body()
|
||||
self._json = msgpack.unpackb(body, raw=False)
|
||||
return self._json
|
||||
|
||||
|
||||
class MsgpackResponse(Response):
|
||||
media_type = "application/msgpack"
|
||||
|
||||
def render(self, content: t.Optional[t.Any]) -> bytes:
|
||||
if content is None:
|
||||
return b""
|
||||
|
||||
if isinstance(content, BaseModel):
|
||||
content = content.dict()
|
||||
ret = msgpack.packb(content, use_bin_type=True)
|
||||
assert ret is not None
|
||||
return ret
|
||||
|
||||
|
||||
class MsgpackRoute(APIRoute):
|
||||
# keep track of content-type -> request classes
|
||||
REQUESTS_CLASSES = {MsgpackRequest.media_type: MsgpackRequest}
|
||||
# keep track of content-type -> response classes
|
||||
ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse}
|
||||
|
||||
def _get_media_type_route_handler(self, media_type):
|
||||
return get_request_handler(
|
||||
dependant=self.dependant,
|
||||
body_field=self.body_field,
|
||||
status_code=self.status_code,
|
||||
# use custom response class or fallback on default self.response_class
|
||||
response_class=self.ROUTES_HANDLERS_CLASSES.get(media_type, self.response_class),
|
||||
response_field=self.secure_cloned_response_field,
|
||||
response_model_include=self.response_model_include,
|
||||
response_model_exclude=self.response_model_exclude,
|
||||
response_model_by_alias=self.response_model_by_alias,
|
||||
response_model_exclude_unset=self.response_model_exclude_unset,
|
||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
)
|
||||
|
||||
def get_route_handler(self) -> t.Callable:
|
||||
async def custom_route_handler(request: Request) -> Response:
|
||||
|
||||
content_type = request.headers.get("Content-Type")
|
||||
try:
|
||||
request_cls = self.REQUESTS_CLASSES[content_type]
|
||||
request = request_cls(request.scope, request.receive)
|
||||
except KeyError:
|
||||
# nothing registered to handle content_type, process given requests as-is
|
||||
pass
|
||||
|
||||
accept = request.headers.get("Accept")
|
||||
route_handler = self._get_media_type_route_handler(accept)
|
||||
return await route_handler(request)
|
||||
|
||||
return custom_route_handler
|
9
etebase_fastapi/sendfile/backends/mod_wsgi.py
Normal file
9
etebase_fastapi/sendfile/backends/mod_wsgi.py
Normal file
@ -0,0 +1,9 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
from ..utils import _convert_file_to_url
|
||||
|
||||
|
||||
def sendfile(filename, **kwargs):
|
||||
return Response(headers={"Location": _convert_file_to_url(filename)})
|
9
etebase_fastapi/sendfile/backends/nginx.py
Normal file
9
etebase_fastapi/sendfile/backends/nginx.py
Normal file
@ -0,0 +1,9 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
from ..utils import _convert_file_to_url
|
||||
|
||||
|
||||
def sendfile(filename, **kwargs):
|
||||
return Response(headers={"X-Accel-Redirect": _convert_file_to_url(filename)})
|
12
etebase_fastapi/sendfile/backends/simple.py
Normal file
12
etebase_fastapi/sendfile/backends/simple.py
Normal file
@ -0,0 +1,12 @@
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
def sendfile(filename, mimetype, **kwargs):
|
||||
"""Use the SENDFILE_ROOT value composed with the path arrived as argument
|
||||
to build an absolute path with which resolve and return the file contents.
|
||||
|
||||
If the path points to a file out of the root directory (should cover both
|
||||
situations with '..' and symlinks) then a 404 is raised.
|
||||
"""
|
||||
|
||||
return FileResponse(filename, media_type=mimetype)
|
6
etebase_fastapi/sendfile/backends/xsendfile.py
Normal file
6
etebase_fastapi/sendfile/backends/xsendfile.py
Normal file
@ -0,0 +1,6 @@
|
||||
from fastapi import Response
|
||||
|
||||
|
||||
def sendfile(filename, **kwargs):
|
||||
filename = str(filename)
|
||||
return Response(headers={"X-Sendfile": filename})
|
@ -4,9 +4,11 @@ from pathlib import Path, PurePath
|
||||
from urllib.parse import quote
|
||||
import logging
|
||||
|
||||
from fastapi import status
|
||||
from ..exceptions import HttpError
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.http import Http404
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -54,12 +56,12 @@ def _sanitize_path(filepath):
|
||||
try:
|
||||
filepath_abs.relative_to(path_root)
|
||||
except ValueError:
|
||||
raise Http404("{} wrt {} is impossible".format(filepath_abs, path_root))
|
||||
raise HttpError("generic", "{} wrt {} is impossible".format(filepath_abs, path_root), status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
return filepath_abs
|
||||
|
||||
|
||||
def sendfile(request, filename, mimetype="application/octet-stream", encoding=None):
|
||||
def sendfile(filename, mimetype="application/octet-stream", encoding=None):
|
||||
"""
|
||||
Create a response to send file using backend configured in ``SENDFILE_BACKEND``
|
||||
|
||||
@ -75,11 +77,10 @@ def sendfile(request, filename, mimetype="application/octet-stream", encoding=No
|
||||
_sendfile = _get_sendfile()
|
||||
|
||||
if not filepath_obj.exists():
|
||||
raise Http404('"%s" does not exist' % filepath_obj)
|
||||
raise HttpError("does_not_exist", '"%s" does not exist' % filepath_obj, status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
response = _sendfile(request, filepath_obj, mimetype=mimetype)
|
||||
response = _sendfile(filepath_obj, mimetype=mimetype)
|
||||
|
||||
response["Content-length"] = filepath_obj.stat().st_size
|
||||
response["Content-Type"] = mimetype
|
||||
response.headers["Content-Type"] = mimetype
|
||||
|
||||
return response
|
62
etebase_fastapi/stoken_handler.py
Normal file
62
etebase_fastapi/stoken_handler.py
Normal file
@ -0,0 +1,62 @@
|
||||
import typing as t
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from fastapi import status
|
||||
|
||||
from django_etebase.models import Stoken
|
||||
|
||||
from .exceptions import HttpError
|
||||
|
||||
# TODO missing stoken_annotation type
|
||||
StokenAnnotation = t.Any
|
||||
|
||||
|
||||
def get_stoken_obj(stoken: t.Optional[str]) -> t.Optional[Stoken]:
|
||||
if stoken is not None:
|
||||
try:
|
||||
return Stoken.objects.get(uid=stoken)
|
||||
except Stoken.DoesNotExist:
|
||||
raise HttpError("bad_stoken", "Invalid stoken.", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def filter_by_stoken(
|
||||
stoken: t.Optional[str], queryset: QuerySet, stoken_annotation: StokenAnnotation
|
||||
) -> t.Tuple[QuerySet, t.Optional[Stoken]]:
|
||||
stoken_rev = get_stoken_obj(stoken)
|
||||
|
||||
queryset = queryset.annotate(max_stoken=stoken_annotation).order_by("max_stoken")
|
||||
|
||||
if stoken_rev is not None:
|
||||
queryset = queryset.filter(max_stoken__gt=stoken_rev.id)
|
||||
|
||||
return queryset, stoken_rev
|
||||
|
||||
|
||||
def get_queryset_stoken(queryset: t.Iterable[t.Any]) -> t.Optional[Stoken]:
|
||||
maxid = -1
|
||||
for row in queryset:
|
||||
rowmaxid = getattr(row, "max_stoken") or -1
|
||||
maxid = max(maxid, rowmaxid)
|
||||
new_stoken = Stoken.objects.get(id=maxid) if (maxid >= 0) else None
|
||||
|
||||
return new_stoken or None
|
||||
|
||||
|
||||
def filter_by_stoken_and_limit(
|
||||
stoken: t.Optional[str], limit: int, queryset: QuerySet, stoken_annotation: StokenAnnotation
|
||||
) -> t.Tuple[list, t.Optional[Stoken], bool]:
|
||||
|
||||
queryset, stoken_rev = filter_by_stoken(stoken=stoken, queryset=queryset, stoken_annotation=stoken_annotation)
|
||||
|
||||
result = list(queryset[: limit + 1])
|
||||
if len(result) < limit + 1:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
result = result[:-1]
|
||||
|
||||
new_stoken_obj = get_queryset_stoken(result) or stoken_rev
|
||||
|
||||
return result, new_stoken_obj, done
|
38
etebase_fastapi/test_reset_view.py
Normal file
38
etebase_fastapi/test_reset_view.py
Normal file
@ -0,0 +1,38 @@
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
from django.shortcuts import get_object_or_404
|
||||
from fastapi import APIRouter, Request, status
|
||||
|
||||
from django_etebase.utils import get_user_queryset, CallbackContext
|
||||
from etebase_fastapi.authentication import SignupIn, signup_save
|
||||
from etebase_fastapi.msgpack import MsgpackRoute
|
||||
from etebase_fastapi.exceptions import HttpError
|
||||
from myauth.models import get_typed_user_model
|
||||
|
||||
test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"])
|
||||
User = get_typed_user_model()
|
||||
|
||||
|
||||
@test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def reset(data: SignupIn, request: Request):
|
||||
# Only run when in DEBUG mode! It's only used for tests
|
||||
if not settings.DEBUG:
|
||||
raise HttpError(code="generic", detail="Only allowed in debug mode.")
|
||||
|
||||
with transaction.atomic():
|
||||
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(request.path_params))
|
||||
user = get_object_or_404(user_queryset, username=data.user.username)
|
||||
# Only allow test users for extra safety
|
||||
if not getattr(user, User.USERNAME_FIELD).startswith("test_user"):
|
||||
raise HttpError(code="generic", detail="Endpoint not allowed for user.")
|
||||
|
||||
if hasattr(user, "userinfo"):
|
||||
user.userinfo.delete()
|
||||
|
||||
signup_save(data, request)
|
||||
# Delete all of the journal data for this user for a clear test env
|
||||
user.collection_set.all().delete()
|
||||
user.collectionmember_set.all().delete()
|
||||
user.incoming_invitations.all().delete()
|
||||
|
||||
# FIXME: also delete chunk files!!!
|
74
etebase_fastapi/utils.py
Normal file
74
etebase_fastapi/utils.py
Normal file
@ -0,0 +1,74 @@
|
||||
import dataclasses
|
||||
import typing as t
|
||||
import msgpack
|
||||
import base64
|
||||
|
||||
from fastapi import status, Query, Depends
|
||||
from pydantic import BaseModel as PyBaseModel
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
|
||||
from django_etebase import app_settings
|
||||
from django_etebase.models import AccessLevels
|
||||
from myauth.models import UserType, get_typed_user_model
|
||||
|
||||
from .exceptions import HttpError, HttpErrorOut
|
||||
|
||||
User = get_typed_user_model()
|
||||
|
||||
Prefetch = t.Literal["auto", "medium"]
|
||||
PrefetchQuery = Query(default="auto")
|
||||
|
||||
|
||||
class BaseModel(PyBaseModel):
|
||||
class Config:
|
||||
json_encoders = {
|
||||
bytes: lambda x: x,
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Context:
|
||||
user: t.Optional[UserType]
|
||||
prefetch: t.Optional[Prefetch]
|
||||
|
||||
|
||||
def get_object_or_404(queryset: QuerySet, **kwargs):
|
||||
try:
|
||||
return queryset.get(**kwargs)
|
||||
except ObjectDoesNotExist as e:
|
||||
raise HttpError("does_not_exist", str(e), status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
||||
def is_collection_admin(collection, user):
|
||||
member = collection.members.filter(user=user).first()
|
||||
return (member is not None) and (member.accessLevel == AccessLevels.ADMIN)
|
||||
|
||||
|
||||
def msgpack_encode(content):
|
||||
return msgpack.packb(content, use_bin_type=True)
|
||||
|
||||
|
||||
def msgpack_decode(content):
|
||||
return msgpack.unpackb(content, raw=False)
|
||||
|
||||
|
||||
def b64encode(value):
|
||||
return base64.urlsafe_b64encode(value).decode("ascii").strip("=")
|
||||
|
||||
|
||||
def b64decode(data):
|
||||
data += "=" * ((4 - len(data) % 4) % 4)
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
|
||||
PERMISSIONS_READ = [Depends(x) for x in app_settings.API_PERMISSIONS_READ]
|
||||
PERMISSIONS_READWRITE = PERMISSIONS_READ + [Depends(x) for x in app_settings.API_PERMISSIONS_WRITE]
|
||||
|
||||
|
||||
response_model_dict = {"model": HttpErrorOut}
|
||||
permission_responses: t.Dict[t.Union[int, str], t.Dict[str, t.Any]] = {
|
||||
401: response_model_dict,
|
||||
403: response_model_dict,
|
||||
}
|
@ -1,16 +1,19 @@
|
||||
"""
|
||||
ASGI config for etebase_server project.
|
||||
|
||||
It exposes the ASGI callable as a module-level variable named ``application``.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "etebase_server.settings")
|
||||
django_application = get_asgi_application()
|
||||
|
||||
application = get_asgi_application()
|
||||
|
||||
def create_application():
|
||||
from etebase_fastapi.main import create_application
|
||||
|
||||
app = create_application()
|
||||
|
||||
app.mount("/", django_application)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
application = create_application()
|
||||
|
@ -53,8 +53,6 @@ INSTALLED_APPS = [
|
||||
"django.contrib.sessions",
|
||||
"django.contrib.messages",
|
||||
"django.contrib.staticfiles",
|
||||
"corsheaders",
|
||||
"rest_framework",
|
||||
"myauth.apps.MyauthConfig",
|
||||
"django_etebase.apps.DjangoEtebaseConfig",
|
||||
"django_etebase.token_auth.apps.TokenAuthConfig",
|
||||
@ -63,7 +61,6 @@ INSTALLED_APPS = [
|
||||
MIDDLEWARE = [
|
||||
"django.middleware.security.SecurityMiddleware",
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"corsheaders.middleware.CorsMiddleware",
|
||||
"django.middleware.common.CommonMiddleware",
|
||||
"django.middleware.csrf.CsrfViewMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
@ -124,9 +121,6 @@ USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
# Cors
|
||||
CORS_ORIGIN_ALLOW_ALL = True
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/3.0/howto/static-files/
|
||||
|
||||
@ -166,11 +160,6 @@ if any(os.path.isfile(x) for x in config_locations):
|
||||
if "database" in config:
|
||||
DATABASES = {"default": {x.upper(): y for x, y in config.items("database")}}
|
||||
|
||||
ETEBASE_API_PERMISSIONS = ("rest_framework.permissions.IsAuthenticated",)
|
||||
ETEBASE_API_AUTHENTICATORS = (
|
||||
"django_etebase.token_auth.authentication.TokenAuthentication",
|
||||
"rest_framework.authentication.SessionAuthentication",
|
||||
)
|
||||
ETEBASE_CREATE_USER_FUNC = "django_etebase.utils.create_user_blocked"
|
||||
|
||||
# Efficient file streaming (for large files)
|
||||
|
@ -1,16 +1,25 @@
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
from django.conf.urls import include, url
|
||||
from django.conf.urls import url
|
||||
from django.contrib import admin
|
||||
from django.urls import path
|
||||
from django.urls import path, re_path
|
||||
from django.views.generic import TemplateView
|
||||
from django.views.static import serve
|
||||
from django.contrib.staticfiles import finders
|
||||
|
||||
urlpatterns = [
|
||||
url(r"^api/", include("django_etebase.urls")),
|
||||
url(r"^admin/", admin.site.urls),
|
||||
path("", TemplateView.as_view(template_name="success.html")),
|
||||
]
|
||||
|
||||
if settings.DEBUG:
|
||||
urlpatterns += [
|
||||
url(r"^api-auth/", include("rest_framework.urls", namespace="rest_framework")),
|
||||
]
|
||||
|
||||
def serve_static(request, path):
|
||||
filename = finders.find(path)
|
||||
dirname = os.path.dirname(filename)
|
||||
basename = os.path.basename(filename)
|
||||
|
||||
return serve(request, basename, dirname)
|
||||
|
||||
urlpatterns += [re_path(r"^static/(?P<path>.*)$", serve_static)]
|
||||
|
@ -1,16 +0,0 @@
|
||||
"""
|
||||
WSGI config for etebase_server project.
|
||||
|
||||
It exposes the WSGI callable as a module-level variable named ``application``.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/3.0/howto/deployment/wsgi/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "etebase_server.settings")
|
||||
|
||||
application = get_wsgi_application()
|
@ -1,8 +1,8 @@
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.forms import UsernameField
|
||||
from myauth.models import get_typed_user_model
|
||||
|
||||
User = get_user_model()
|
||||
User = get_typed_user_model()
|
||||
|
||||
|
||||
class AdminUserCreationForm(forms.ModelForm):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import typing as t
|
||||
|
||||
from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
@ -28,9 +30,21 @@ class User(AbstractUser):
|
||||
unique=True,
|
||||
help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."),
|
||||
validators=[username_validator],
|
||||
error_messages={"unique": _("A user with that username already exists."),},
|
||||
error_messages={
|
||||
"unique": _("A user with that username already exists."),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def normalize_username(cls, username):
|
||||
return super().normalize_username(username).lower()
|
||||
|
||||
|
||||
UserType = t.Type[User]
|
||||
|
||||
|
||||
def get_typed_user_model() -> UserType:
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
ret: t.Any = get_user_model()
|
||||
return ret
|
||||
|
28
requirements-dev.txt
Normal file
28
requirements-dev.txt
Normal file
@ -0,0 +1,28 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile
|
||||
# To update, run:
|
||||
#
|
||||
# pip-compile --output-file=requirements-dev.txt requirements.in/development.txt
|
||||
#
|
||||
appdirs==1.4.4 # via black
|
||||
asgiref==3.3.1 # via django
|
||||
black==20.8b1 # via -r requirements.in/development.txt
|
||||
click==7.1.2 # via black, pip-tools
|
||||
coverage==5.3.1 # via -r requirements.in/development.txt
|
||||
django-stubs==1.7.0 # via -r requirements.in/development.txt
|
||||
django==3.1.4 # via django-stubs
|
||||
mypy-extensions==0.4.3 # via black, mypy
|
||||
mypy==0.790 # via django-stubs
|
||||
pathspec==0.8.1 # via black
|
||||
pip-tools==5.4.0 # via -r requirements.in/development.txt
|
||||
pytz==2020.5 # via django
|
||||
pywatchman==1.4.1 # via -r requirements.in/development.txt
|
||||
regex==2020.11.13 # via black
|
||||
six==1.15.0 # via pip-tools
|
||||
sqlparse==0.4.1 # via django
|
||||
toml==0.10.2 # via black
|
||||
typed-ast==1.4.1 # via black, mypy
|
||||
typing-extensions==3.7.4.3 # via black, django-stubs, mypy
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# pip
|
@ -1,7 +1,5 @@
|
||||
django
|
||||
django-cors-headers
|
||||
djangorestframework
|
||||
drf-nested-routers
|
||||
msgpack
|
||||
psycopg2-binary
|
||||
pynacl
|
||||
fastapi
|
||||
uvicorn
|
||||
|
@ -1,4 +1,5 @@
|
||||
coverage
|
||||
pip-tools
|
||||
pywatchman
|
||||
black
|
||||
black
|
||||
django-stubs
|
||||
|
@ -4,16 +4,18 @@
|
||||
#
|
||||
# pip-compile --output-file=requirements.txt requirements.in/base.txt
|
||||
#
|
||||
asgiref==3.2.10 # via django
|
||||
cffi==1.14.0 # via pynacl
|
||||
django-cors-headers==3.2.1 # via -r requirements.in/base.txt
|
||||
django==3.1.1 # via -r requirements.in/base.txt, django-cors-headers, djangorestframework, drf-nested-routers
|
||||
djangorestframework==3.11.0 # via -r requirements.in/base.txt, drf-nested-routers
|
||||
drf-nested-routers==0.91 # via -r requirements.in/base.txt
|
||||
msgpack==1.0.0 # via -r requirements.in/base.txt
|
||||
psycopg2-binary==2.8.4 # via -r requirements.in/base.txt
|
||||
asgiref==3.3.1 # via django
|
||||
cffi==1.14.4 # via pynacl
|
||||
click==7.1.2 # via uvicorn
|
||||
django==3.1.4 # via -r requirements.in/base.txt
|
||||
fastapi==0.63.0 # via -r requirements.in/base.txt
|
||||
h11==0.11.0 # via uvicorn
|
||||
msgpack==1.0.2 # via -r requirements.in/base.txt
|
||||
pycparser==2.20 # via cffi
|
||||
pynacl==1.3.0 # via -r requirements.in/base.txt
|
||||
pytz==2019.3 # via django
|
||||
six==1.14.0 # via pynacl
|
||||
sqlparse==0.3.0 # via django
|
||||
pydantic==1.7.3 # via fastapi
|
||||
pynacl==1.4.0 # via -r requirements.in/base.txt
|
||||
pytz==2020.4 # via django
|
||||
six==1.15.0 # via pynacl
|
||||
sqlparse==0.4.1 # via django
|
||||
starlette==0.13.6 # via fastapi
|
||||
uvicorn==0.13.2 # via -r requirements.in/base.txt
|
||||
|
Loading…
Reference in New Issue
Block a user