1
0
mirror of https://github.com/etesync/server synced 2024-12-28 18:28:07 +00:00

Pass generic context to callbacks instead of the whole view

This commit is contained in:
Tom Hacohen 2020-12-27 15:03:07 +02:00
parent 5a6c8a1d05
commit c2eb4fd30c
3 changed files with 26 additions and 12 deletions

View File

@ -20,7 +20,7 @@ from django.contrib.auth import get_user_model
from django.db import IntegrityError, transaction
from rest_framework import serializers, status
from . import models
from .utils import get_user_queryset, create_user
from .utils import get_user_queryset, create_user, CallbackContext
from .exceptions import EtebaseValidationError
@ -102,7 +102,7 @@ class CollectionTypeField(BinaryBase64Field):
class UserSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
view = self.context.get("view", None)
return get_user_queryset(super().get_queryset(), view)
return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs))
def __init__(self, **kwargs):
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
@ -515,12 +515,17 @@ class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer):
with transaction.atomic():
try:
view = self.context.get("view", None)
user_queryset = get_user_queryset(User.objects.all(), view)
user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs))
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
except User.DoesNotExist:
# Create the user and save the casing the user chose as the first name
try:
instance = create_user(**user_data, password=None, first_name=user_data["username"], view=view)
instance = create_user(
**user_data,
password=None,
first_name=user_data["username"],
context=CallbackContext(view.kwargs)
)
instance.full_clean()
except EtebaseValidationError as e:
raise e

View File

@ -1,3 +1,6 @@
import typing as t
from dataclasses import dataclass
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied
@ -7,18 +10,24 @@ from . import app_settings
User = get_user_model()
def get_user_queryset(queryset, view):
@dataclass
class CallbackContext:
"""Class for passing extra context to callbacks"""
url_kwargs: t.Dict[str, t.Any]
def get_user_queryset(queryset, context: CallbackContext):
custom_func = app_settings.GET_USER_QUERYSET_FUNC
if custom_func is not None:
return custom_func(queryset, view)
return custom_func(queryset, context)
return queryset
def create_user(*args, **kwargs):
def create_user(context: CallbackContext, *args, **kwargs):
custom_func = app_settings.CREATE_USER_FUNC
if custom_func is not None:
return custom_func(*args, **kwargs)
_ = kwargs.pop("view")
return User.objects.create_user(*args, **kwargs)

View File

@ -73,7 +73,7 @@ from .serializers import (
UserInfoPubkeySerializer,
UserSerializer,
)
from .utils import get_user_queryset
from .utils import get_user_queryset, CallbackContext
from .exceptions import EtebaseValidationError
from .parsers import ChunkUploadParser
from .signals import user_signed_up
@ -598,7 +598,7 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet):
def fetch_user_profile(self, request, *args, **kwargs):
username = request.GET.get("username")
kwargs = {User.USERNAME_FIELD: username.lower()}
user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs)
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs)
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
serializer = UserInfoPubkeySerializer(user_info)
return Response(serializer.data)
@ -642,7 +642,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
)
def get_queryset(self):
return get_user_queryset(User.objects.all(), self)
return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
def get_serializer_context(self):
return {"request": self.request, "format": self.format_kwarg, "view": self}
@ -837,7 +837,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
return HttpResponseBadRequest("Only allowed in debug mode.")
with transaction.atomic():
user_queryset = get_user_queryset(User.objects.all(), self)
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
# Only allow test users for extra safety