From df19887af7df5a78b394ec8d8182bc9cf82bfc4c Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Sun, 27 Dec 2020 22:27:33 +0200 Subject: [PATCH] Use dependency injection for getting collection/item queryset. --- etebase_fastapi/collection.py | 44 +++++++++++++++++------------------ etebase_fastapi/invitation.py | 28 ++++++++++------------ etebase_fastapi/member.py | 34 +++++++++------------------ 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/etebase_fastapi/collection.py b/etebase_fastapi/collection.py index 1fc6f0a..196bb1d 100644 --- a/etebase_fastapi/collection.py +++ b/etebase_fastapi/collection.py @@ -194,18 +194,19 @@ def collection_list_common( return MsgpackResponse(content=ret) -def get_collection_queryset(user: User) -> QuerySet: +def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet: return default_queryset.filter(members__user=user) -def get_item_queryset( - user: User, collection_uid: str, queryset: QuerySet = default_item_queryset -) -> t.Tuple[models.Collection, QuerySet]: - collection = get_object_or_404(get_collection_queryset(user), uid=collection_uid) +def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection: + return get_object_or_404(queryset, uid=collection_uid) + + +def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: # XXX Potentially add this for performance: .prefetch_related('revisions__chunks') - queryset = queryset.filter(collection__pk=collection.pk, revisions__current=True) + queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True) - return collection, queryset + return queryset @collection_router.post("/list_multi/") @@ -213,11 +214,10 @@ async def list_multi( data: ListMulti, stoken: t.Optional[str] = None, limit: int = 50, + queryset: QuerySet = Depends(get_collection_queryset), user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): - queryset = get_collection_queryset(user) - # FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration") queryset = queryset.filter( Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True) @@ -228,13 +228,12 @@ async def list_multi( @collection_router.post("/list/") async def collection_list( - req: Request, stoken: t.Optional[str] = None, limit: int = 50, prefetch: Prefetch = PrefetchQuery, user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_collection_queryset), ): - queryset = get_collection_queryset(user) return await collection_list_common(queryset, user, stoken, limit, prefetch) @@ -309,9 +308,12 @@ async def create(data: CollectionIn, user: User = Depends(get_authenticated_user return MsgpackResponse({}, status_code=status.HTTP_201_CREATED) -@collection_router.get("/{uid}/") -def collection_get(uid: str, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery): - obj = get_collection_queryset(user).get(uid=uid) +@collection_router.get("/{collection_uid}/") +def collection_get( + obj: models.Collection = Depends(get_collection), + user: User = Depends(get_authenticated_user), + prefetch: Prefetch = PrefetchQuery + ): ret = CollectionOut.from_orm_context(obj, Context(user, prefetch)) return MsgpackResponse(ret) @@ -358,9 +360,10 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val @collection_router.get("/{collection_uid}/item/{uid}/") def item_get( - collection_uid: str, uid: str, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery + uid: str, + queryset: QuerySet = Depends(get_item_queryset), + user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): - _, queryset = get_item_queryset(user, collection_uid) obj = queryset.get(uid=uid) ret = CollectionItemOut.from_orm_context(obj, Context(user, prefetch)) return MsgpackResponse(ret) @@ -386,14 +389,13 @@ def item_list_common( @collection_router.get("/{collection_uid}/item/") async def item_list( - collection_uid: str, + queryset: QuerySet = Depends(get_item_queryset), stoken: t.Optional[str] = None, limit: int = 50, prefetch: Prefetch = PrefetchQuery, withCollection: bool = False, user: User = Depends(get_authenticated_user), ): - _, queryset = await sync_to_async(get_item_queryset)(user, collection_uid) if not withCollection: queryset = queryset.filter(parent__isnull=True) @@ -419,14 +421,13 @@ def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid @collection_router.get("/{collection_uid}/item/{uid}/revision/") def item_revisions( - collection_uid: str, uid: str, limit: int = 50, iterator: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: User = Depends(get_authenticated_user), + items: QuerySet = Depends(get_item_queryset), ): - _, items = get_item_queryset(user, collection_uid) item = get_object_or_404(items, uid=uid) queryset = item.revisions.order_by("-id") @@ -456,13 +457,12 @@ def item_revisions( @collection_router.post("/{collection_uid}/item/fetch_updates/") def fetch_updates( - collection_uid: str, data: t.List[CollectionItemBulkGetIn], stoken: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_item_queryset), ): - _, queryset = get_item_queryset(user, collection_uid) # FIXME: make configurable? item_limit = 200 diff --git a/etebase_fastapi/invitation.py b/etebase_fastapi/invitation.py index 077dcfd..cbf0554 100644 --- a/etebase_fastapi/invitation.py +++ b/etebase_fastapi/invitation.py @@ -73,12 +73,12 @@ class InvitationListResponse(BaseModel): done: bool -def get_incoming_queryset(user: User, queryset=default_queryset): - return queryset.filter(user=user) +def get_incoming_queryset(user: User = Depends(get_authenticated_user)): + return default_queryset.filter(user=user) -def get_outgoing_queryset(user: User, queryset=default_queryset): - return queryset.filter(fromMember__user=user) +def get_outgoing_queryset(user: User = Depends(get_authenticated_user)): + return default_queryset.filter(fromMember__user=user) def list_common( @@ -114,17 +114,16 @@ def list_common( def incoming_list( iterator: t.Optional[str] = None, limit: int = 50, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_incoming_queryset), ): - return list_common(get_incoming_queryset(user), iterator, limit) + return list_common(queryset, iterator, limit) @invitation_incoming_router.get("/{invitation_uid}/", response_model=CollectionInvitationOut) def incoming_get( invitation_uid: str, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_incoming_queryset), ): - queryset = get_incoming_queryset(user) obj = get_object_or_404(queryset, uid=invitation_uid) ret = CollectionInvitationOut.from_orm(obj) return MsgpackResponse(ret) @@ -133,9 +132,8 @@ def incoming_get( @invitation_incoming_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT) def incoming_delete( invitation_uid: str, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_incoming_queryset), ): - queryset = get_incoming_queryset(user) obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() @@ -144,9 +142,8 @@ def incoming_delete( def incoming_accept( invitation_uid: str, data: CollectionInvitationAcceptIn, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_incoming_queryset), ): - queryset = get_incoming_queryset(user) invitation = get_object_or_404(queryset, uid=invitation_uid) with transaction.atomic(): @@ -201,17 +198,16 @@ def outgoing_create( def outgoing_list( iterator: t.Optional[str] = None, limit: int = 50, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_outgoing_queryset), ): - return list_common(get_outgoing_queryset(user), iterator, limit) + return list_common(queryset, iterator, limit) @invitation_outgoing_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT) def outgoing_delete( invitation_uid: str, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_outgoing_queryset), ): - queryset = get_outgoing_queryset(user) obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() diff --git a/etebase_fastapi/member.py b/etebase_fastapi/member.py index 534cad1..a491490 100644 --- a/etebase_fastapi/member.py +++ b/etebase_fastapi/member.py @@ -12,15 +12,18 @@ from .msgpack import MsgpackResponse from .utils import get_object_or_404 from .stoken_handler import filter_by_stoken_and_limit -from .collection import collection_router, get_collection_queryset +from .collection import collection_router, get_collection User = get_user_model() default_queryset: QuerySet = models.CollectionMember.objects.all() -def get_queryset(user: User, collection_uid: str, queryset=default_queryset) -> t.Tuple[models.Collection, QuerySet]: - collection = get_object_or_404(get_collection_queryset(user), uid=collection_uid) - return collection, queryset.filter(collection=collection) +def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: + return default_queryset.filter(collection=collection) + + +def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet: + return get_object_or_404(queryset, user__username__iexact=username) class CollectionMemberModifyAccessLevelIn(BaseModel): @@ -47,12 +50,10 @@ class MemberListResponse(BaseModel): @collection_router.get("/{collection_uid}/member/", response_model=MemberListResponse) def member_list( - collection_uid: str, iterator: t.Optional[str] = None, limit: int = 50, - user: User = Depends(get_authenticated_user), + queryset: QuerySet = Depends(get_queryset), ): - _, queryset = get_queryset(user, collection_uid) queryset = queryset.order_by("id") result, new_stoken_obj, done = filter_by_stoken_and_limit( iterator, limit, queryset, models.CollectionMember.stoken_annotation @@ -69,25 +70,16 @@ def member_list( @collection_router.delete("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT) def member_delete( - collection_uid: str, - username: str, - user: User = Depends(get_authenticated_user), + obj: models.CollectionMember = Depends(get_member), ): - _, queryset = get_queryset(user, collection_uid) - obj = get_object_or_404(queryset, user__username__iexact=username) obj.revoke() @collection_router.patch("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT) def member_patch( - collection_uid: str, - username: str, data: CollectionMemberModifyAccessLevelIn, - user: User = Depends(get_authenticated_user), + instance: models.CollectionMember = Depends(get_member), ): - _, queryset = get_queryset(user, collection_uid) - instance = get_object_or_404(queryset, user__username__iexact=username) - with transaction.atomic(): # We only allow updating accessLevel if instance.accessLevel != data.accessLevel: @@ -97,10 +89,6 @@ def member_patch( @collection_router.post("/{collection_uid}/member/leave/", status_code=status.HTTP_204_NO_CONTENT) -def member_leave( - collection_uid: str, - user: User = Depends(get_authenticated_user), -): - collection, _ = get_queryset(user, collection_uid) +def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)): obj = get_object_or_404(collection.members, user=user) obj.revoke()