From f9add36f18e04288a4b5ea7768666d70ebc1abc0 Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Mon, 13 Jul 2020 14:30:18 +0300 Subject: [PATCH] Add support for custom user filtering. --- django_etebase/app_settings.py | 7 +++++++ django_etebase/serializers.py | 16 ++++++++++++---- django_etebase/utils.py | 12 ++++++++++++ django_etebase/views.py | 11 +++++++---- 4 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 django_etebase/utils.py diff --git a/django_etebase/app_settings.py b/django_etebase/app_settings.py index b1fb4c3..7fe30b7 100644 --- a/django_etebase/app_settings.py +++ b/django_etebase/app_settings.py @@ -46,6 +46,13 @@ class AppSettings: ret.append(self.import_from_str(perm)) return ret + @property + def GET_USER_QUERYSET(self): # pylint: disable=invalid-name + get_user_queryset = self._setting("GET_USER_QUERYSET", None) + if get_user_queryset is not None: + return self.import_from_str(get_user_queryset) + return None + @property def CHALLENGE_VALID_SECONDS(self): # pylint: disable=invalid-name return self._setting("CHALLENGE_VALID_SECONDS", 60) diff --git a/django_etebase/serializers.py b/django_etebase/serializers.py index 13199b3..0655775 100644 --- a/django_etebase/serializers.py +++ b/django_etebase/serializers.py @@ -20,6 +20,7 @@ from django.contrib.auth import get_user_model from django.db import transaction from rest_framework import serializers from . import models +from .utils import get_user_queryset User = get_user_model() @@ -91,6 +92,15 @@ class CollectionContentField(BinaryBase64Field): return None +class UserSlugRelatedField(serializers.SlugRelatedField): + def get_queryset(self): + view = self.context.get('view', None) + return get_user_queryset(super().get_queryset(), view) + + def __init__(self, **kwargs): + super().__init__(slug_field=User.USERNAME_FIELD, **kwargs) + + class ChunksField(serializers.RelatedField): def to_representation(self, obj): obj = obj.chunk @@ -252,9 +262,8 @@ class CollectionSerializer(serializers.ModelSerializer): class CollectionMemberSerializer(serializers.ModelSerializer): - username = serializers.SlugRelatedField( + username = UserSlugRelatedField( source='user', - slug_field=User.USERNAME_FIELD, read_only=True, ) @@ -278,9 +287,8 @@ class CollectionMemberSerializer(serializers.ModelSerializer): class CollectionInvitationSerializer(serializers.ModelSerializer): - username = serializers.SlugRelatedField( + username = UserSlugRelatedField( source='user', - slug_field=User.USERNAME_FIELD, queryset=User.objects ) collection = serializers.CharField(source='collection.uid') diff --git a/django_etebase/utils.py b/django_etebase/utils.py new file mode 100644 index 0000000..315b82f --- /dev/null +++ b/django_etebase/utils.py @@ -0,0 +1,12 @@ +from django.contrib.auth import get_user_model +from . import app_settings + + +User = get_user_model() + + +def get_user_queryset(queryset, view): + custom_func = app_settings.GET_USER_QUERYSET + if custom_func is not None: + return custom_func(queryset, view) + return queryset diff --git a/django_etebase/views.py b/django_etebase/views.py index 7e8bf98..480843e 100644 --- a/django_etebase/views.py +++ b/django_etebase/views.py @@ -71,6 +71,7 @@ from .serializers import ( UserInfoPubkeySerializer, UserSerializer, ) +from .utils import get_user_queryset User = get_user_model() @@ -558,8 +559,9 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet): @action_decorator(detail=False, allowed_methods=['GET'], methods=['GET']) def fetch_user_profile(self, request, *args, **kwargs): username = request.GET.get('username') - kwargs = {'owner__' + User.USERNAME_FIELD: username} - user_info = get_object_or_404(UserInfo.objects.all(), **kwargs) + kwargs = {User.USERNAME_FIELD: username} + user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs) + user_info = get_object_or_404(UserInfo.objects.all(), owner=user) serializer = UserInfoPubkeySerializer(user_info) return Response(serializer.data) @@ -597,7 +599,7 @@ class AuthenticationViewSet(viewsets.ViewSet): encoder=nacl.encoding.RawEncoder) def get_queryset(self): - return User.objects.all() + return get_user_queryset(User.objects.all(), self) def login_response_data(self, user): return { @@ -756,7 +758,8 @@ class TestAuthenticationViewSet(viewsets.ViewSet): return HttpResponseBadRequest("Only allowed in debug mode.") with transaction.atomic(): - user = get_object_or_404(User.objects.all(), username=request.data.get('user').get('username')) + user_queryset = get_user_queryset(User.objects.all(), self) + user = get_object_or_404(user_queryset, username=request.data.get('user').get('username')) # Only allow test users for extra safety if not getattr(user, User.USERNAME_FIELD).startswith('test_user'):