diff --git a/django_etebase/serializers.py b/django_etebase/serializers.py index 06bc8ad..dce2fe6 100644 --- a/django_etebase/serializers.py +++ b/django_etebase/serializers.py @@ -20,7 +20,7 @@ from django.contrib.auth import get_user_model from django.db import IntegrityError, transaction from rest_framework import serializers, status from . import models -from .utils import get_user_queryset, create_user +from .utils import get_user_queryset, create_user, CallbackContext from .exceptions import EtebaseValidationError @@ -102,7 +102,7 @@ class CollectionTypeField(BinaryBase64Field): class UserSlugRelatedField(serializers.SlugRelatedField): def get_queryset(self): view = self.context.get("view", None) - return get_user_queryset(super().get_queryset(), view) + return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs)) def __init__(self, **kwargs): super().__init__(slug_field=User.USERNAME_FIELD, **kwargs) @@ -515,12 +515,17 @@ class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer): with transaction.atomic(): try: view = self.context.get("view", None) - user_queryset = get_user_queryset(User.objects.all(), view) + user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs)) instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()}) except User.DoesNotExist: # Create the user and save the casing the user chose as the first name try: - instance = create_user(**user_data, password=None, first_name=user_data["username"], view=view) + instance = create_user( + **user_data, + password=None, + first_name=user_data["username"], + context=CallbackContext(view.kwargs) + ) instance.full_clean() except EtebaseValidationError as e: raise e diff --git a/django_etebase/utils.py b/django_etebase/utils.py index e496a77..1c8654b 100644 --- a/django_etebase/utils.py +++ b/django_etebase/utils.py @@ -1,3 +1,6 @@ +import typing as t +from dataclasses import dataclass + from django.contrib.auth import get_user_model from django.core.exceptions import PermissionDenied @@ -7,18 +10,24 @@ from . import app_settings User = get_user_model() -def get_user_queryset(queryset, view): +@dataclass +class CallbackContext: + """Class for passing extra context to callbacks""" + + url_kwargs: t.Dict[str, t.Any] + + +def get_user_queryset(queryset, context: CallbackContext): custom_func = app_settings.GET_USER_QUERYSET_FUNC if custom_func is not None: - return custom_func(queryset, view) + return custom_func(queryset, context) return queryset -def create_user(*args, **kwargs): +def create_user(context: CallbackContext, *args, **kwargs): custom_func = app_settings.CREATE_USER_FUNC if custom_func is not None: return custom_func(*args, **kwargs) - _ = kwargs.pop("view") return User.objects.create_user(*args, **kwargs) diff --git a/django_etebase/views.py b/django_etebase/views.py index 1de5ed7..5a03aa4 100644 --- a/django_etebase/views.py +++ b/django_etebase/views.py @@ -73,7 +73,7 @@ from .serializers import ( UserInfoPubkeySerializer, UserSerializer, ) -from .utils import get_user_queryset +from .utils import get_user_queryset, CallbackContext from .exceptions import EtebaseValidationError from .parsers import ChunkUploadParser from .signals import user_signed_up @@ -598,7 +598,7 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet): def fetch_user_profile(self, request, *args, **kwargs): username = request.GET.get("username") kwargs = {User.USERNAME_FIELD: username.lower()} - user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs) + user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs) user_info = get_object_or_404(UserInfo.objects.all(), owner=user) serializer = UserInfoPubkeySerializer(user_info) return Response(serializer.data) @@ -642,7 +642,7 @@ class AuthenticationViewSet(viewsets.ViewSet): ) def get_queryset(self): - return get_user_queryset(User.objects.all(), self) + return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)) def get_serializer_context(self): return {"request": self.request, "format": self.format_kwarg, "view": self} @@ -837,7 +837,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet): return HttpResponseBadRequest("Only allowed in debug mode.") with transaction.atomic(): - user_queryset = get_user_queryset(User.objects.all(), self) + user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)) user = get_object_or_404(user_queryset, username=request.data.get("user").get("username")) # Only allow test users for extra safety