From d55204b96d6c6ae154d03c8b3d505054381ce176 Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Tue, 29 Dec 2020 17:18:09 +0200 Subject: [PATCH] Improve typing information. --- django_etebase/utils.py | 5 +++-- etebase_fastapi/routers/collection.py | 18 ++++++++++-------- etebase_fastapi/routers/invitation.py | 19 ++++++++++--------- etebase_fastapi/routers/member.py | 9 +++++---- etebase_fastapi/utils.py | 7 +++++-- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/django_etebase/utils.py b/django_etebase/utils.py index d812ae3..3a05fd4 100644 --- a/django_etebase/utils.py +++ b/django_etebase/utils.py @@ -1,6 +1,7 @@ import typing as t from dataclasses import dataclass +from django.db.models import QuerySet from django.core.exceptions import PermissionDenied from myauth.models import UserType, get_typed_user_model @@ -18,14 +19,14 @@ class CallbackContext: user: t.Optional[UserType] = None -def get_user_queryset(queryset, context: CallbackContext): +def get_user_queryset(queryset: QuerySet[UserType], context: CallbackContext) -> QuerySet[UserType]: custom_func = app_settings.GET_USER_QUERYSET_FUNC if custom_func is not None: return custom_func(queryset, context) return queryset -def create_user(context: CallbackContext, *args, **kwargs): +def create_user(context: CallbackContext, *args, **kwargs) -> UserType: custom_func = app_settings.CREATE_USER_FUNC if custom_func is not None: return custom_func(context, *args, **kwargs) diff --git a/etebase_fastapi/routers/collection.py b/etebase_fastapi/routers/collection.py index 56afd7b..4825626 100644 --- a/etebase_fastapi/routers/collection.py +++ b/etebase_fastapi/routers/collection.py @@ -30,6 +30,8 @@ from ..sendfile import sendfile User = get_typed_user_model collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) +CollectionQuerySet = QuerySet[models.Collection] +CollectionItemQuerySet = QuerySet[models.CollectionItem] class ListMulti(BaseModel): @@ -187,7 +189,7 @@ class ItemBatchIn(BaseModel): @sync_to_async def collection_list_common( - queryset: QuerySet, + queryset: CollectionQuerySet, user: UserType, stoken: t.Optional[str], limit: int, @@ -249,7 +251,7 @@ async def list_multi( data: ListMulti, stoken: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_collection_queryset), + queryset: CollectionQuerySet = Depends(get_collection_queryset), user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): @@ -267,7 +269,7 @@ async def collection_list( limit: int = 50, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - queryset: QuerySet = Depends(get_collection_queryset), + queryset: CollectionQuerySet = Depends(get_collection_queryset), ): return await collection_list_common(queryset, user, stoken, limit, prefetch) @@ -395,7 +397,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val @item_router.get("/item/{item_uid}/", response_model=CollectionItemOut, dependencies=PERMISSIONS_READ) def item_get( item_uid: str, - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): @@ -405,7 +407,7 @@ def item_get( @sync_to_async def item_list_common( - queryset: QuerySet, + queryset: CollectionItemQuerySet, user: UserType, stoken: t.Optional[str], limit: int, @@ -422,7 +424,7 @@ def item_list_common( @item_router.get("/item/", response_model=CollectionItemListResponse, dependencies=PERMISSIONS_READ) async def item_list( - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), stoken: t.Optional[str] = None, limit: int = 50, prefetch: Prefetch = PrefetchQuery, @@ -471,7 +473,7 @@ def item_revisions( iterator: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - items: QuerySet = Depends(get_item_queryset), + items: CollectionItemQuerySet = Depends(get_item_queryset), ): item = get_object_or_404(items, uid=item_uid) @@ -505,7 +507,7 @@ def fetch_updates( stoken: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, user: UserType = Depends(get_authenticated_user), - queryset: QuerySet = Depends(get_item_queryset), + queryset: CollectionItemQuerySet = Depends(get_item_queryset), ): # FIXME: make configurable? item_limit = 200 diff --git a/etebase_fastapi/routers/invitation.py b/etebase_fastapi/routers/invitation.py index 6a06c60..aceb05d 100644 --- a/etebase_fastapi/routers/invitation.py +++ b/etebase_fastapi/routers/invitation.py @@ -23,7 +23,8 @@ from ..utils import ( User = get_typed_user_model() invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) -default_queryset: QuerySet = models.CollectionInvitation.objects.all() +InvitationQuerySet = QuerySet[models.CollectionInvitation] +default_queryset: InvitationQuerySet = models.CollectionInvitation.objects.all() class UserInfoOut(BaseModel): @@ -94,7 +95,7 @@ def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)): def list_common( - queryset: QuerySet, + queryset: InvitationQuerySet, iterator: t.Optional[str], limit: int, ) -> InvitationListResponse: @@ -125,7 +126,7 @@ def list_common( def incoming_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): return list_common(queryset, iterator, limit) @@ -135,7 +136,7 @@ def incoming_list( ) def incoming_get( invitation_uid: str, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) return CollectionInvitationOut.from_orm(obj) @@ -146,7 +147,7 @@ def incoming_get( ) def incoming_delete( invitation_uid: str, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() @@ -158,7 +159,7 @@ def incoming_delete( def incoming_accept( invitation_uid: str, data: CollectionInvitationAcceptIn, - queryset: QuerySet = Depends(get_incoming_queryset), + queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): invitation = get_object_or_404(queryset, uid=invitation_uid) @@ -201,7 +202,7 @@ def outgoing_create( with transaction.atomic(): try: - ret = models.CollectionInvitation.objects.create( + models.CollectionInvitation.objects.create( **data.dict(exclude={"collection", "username"}), user=to_user, fromMember=member ) except IntegrityError: @@ -212,7 +213,7 @@ def outgoing_create( def outgoing_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_outgoing_queryset), + queryset: InvitationQuerySet = Depends(get_outgoing_queryset), ): return list_common(queryset, iterator, limit) @@ -222,7 +223,7 @@ def outgoing_list( ) def outgoing_delete( invitation_uid: str, - queryset: QuerySet = Depends(get_outgoing_queryset), + queryset: InvitationQuerySet = Depends(get_outgoing_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) obj.delete() diff --git a/etebase_fastapi/routers/member.py b/etebase_fastapi/routers/member.py index 210374c..41393bf 100644 --- a/etebase_fastapi/routers/member.py +++ b/etebase_fastapi/routers/member.py @@ -15,14 +15,15 @@ from .collection import get_collection, verify_collection_admin User = get_typed_user_model() member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) -default_queryset: QuerySet = models.CollectionMember.objects.all() +MemberQuerySet = QuerySet[models.CollectionMember] +default_queryset: MemberQuerySet = models.CollectionMember.objects.all() -def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet: +def get_queryset(collection: models.Collection = Depends(get_collection)) -> MemberQuerySet: return default_queryset.filter(collection=collection) -def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet: +def get_member(username: str, queryset: MemberQuerySet = Depends(get_queryset)) -> models.CollectionMember: return get_object_or_404(queryset, user__username__iexact=username) @@ -54,7 +55,7 @@ class MemberListResponse(BaseModel): def member_list( iterator: t.Optional[str] = None, limit: int = 50, - queryset: QuerySet = Depends(get_queryset), + queryset: MemberQuerySet = Depends(get_queryset), ): queryset = queryset.order_by("id") result, new_stoken_obj, done = filter_by_stoken_and_limit( diff --git a/etebase_fastapi/utils.py b/etebase_fastapi/utils.py index 3a091c5..b52b7c7 100644 --- a/etebase_fastapi/utils.py +++ b/etebase_fastapi/utils.py @@ -6,7 +6,7 @@ import base64 from fastapi import status, Query, Depends from pydantic import BaseModel as PyBaseModel -from django.db.models import QuerySet +from django.db.models import Model, QuerySet from django.core.exceptions import ObjectDoesNotExist from django_etebase import app_settings @@ -21,6 +21,9 @@ Prefetch = t.Literal["auto", "medium"] PrefetchQuery = Query(default="auto") +T = t.TypeVar("T", bound=Model, covariant=True) + + class BaseModel(PyBaseModel): class Config: json_encoders = { @@ -34,7 +37,7 @@ class Context: prefetch: t.Optional[Prefetch] -def get_object_or_404(queryset: QuerySet, **kwargs): +def get_object_or_404(queryset: QuerySet[T], **kwargs) -> T: try: return queryset.get(**kwargs) except ObjectDoesNotExist as e: