diff --git a/django_etebase/token_auth/models.py b/django_etebase/token_auth/models.py index ac1efff..dd5ae87 100644 --- a/django_etebase/token_auth/models.py +++ b/django_etebase/token_auth/models.py @@ -1,9 +1,9 @@ -from django.contrib.auth import get_user_model from django.db import models from django.utils import timezone from django.utils.crypto import get_random_string +from myauth.models import get_typed_user_model -User = get_user_model() +User = get_typed_user_model() def generate_key(): diff --git a/django_etebase/utils.py b/django_etebase/utils.py index 4d36a94..d812ae3 100644 --- a/django_etebase/utils.py +++ b/django_etebase/utils.py @@ -1,13 +1,13 @@ import typing as t from dataclasses import dataclass -from django.contrib.auth import get_user_model from django.core.exceptions import PermissionDenied +from myauth.models import UserType, get_typed_user_model from . import app_settings -User = get_user_model() +User = get_typed_user_model() @dataclass @@ -15,7 +15,7 @@ class CallbackContext: """Class for passing extra context to callbacks""" url_kwargs: t.Dict[str, t.Any] - user: t.Optional[User] = None + user: t.Optional[UserType] = None def get_user_queryset(queryset, context: CallbackContext): diff --git a/etebase_fastapi/authentication.py b/etebase_fastapi/authentication.py index fe522f7..064d2da 100644 --- a/etebase_fastapi/authentication.py +++ b/etebase_fastapi/authentication.py @@ -9,7 +9,7 @@ import nacl.secret import nacl.signing from asgiref.sync import sync_to_async from django.conf import settings -from django.contrib.auth import get_user_model, user_logged_out, user_logged_in +from django.contrib.auth import user_logged_out, user_logged_in from django.core import exceptions as django_exceptions from django.db import transaction from fastapi import APIRouter, Depends, status, Request @@ -19,12 +19,13 @@ from django_etebase.token_auth.models import AuthToken from django_etebase.models import UserInfo from django_etebase.signals import user_signed_up from django_etebase.utils import create_user, get_user_queryset, CallbackContext +from myauth.models import UserType, get_typed_user_model from .exceptions import AuthenticationFailed, transform_validation_error, HttpError from .msgpack import MsgpackRoute from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode from .dependencies import AuthData, get_auth_data, get_authenticated_user -User = get_user_model() +User = get_typed_user_model() authentication_router = APIRouter(route_class=MsgpackRoute) @@ -52,7 +53,7 @@ class UserOut(BaseModel): encryptedContent: bytes @classmethod - def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut": + def from_orm(cls: t.Type["UserOut"], obj: UserType) -> "UserOut": return cls( username=obj.username, email=obj.email, @@ -66,7 +67,7 @@ class LoginOut(BaseModel): user: UserOut @classmethod - def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut": + def from_orm(cls: t.Type["LoginOut"], obj: UserType) -> "LoginOut": token = AuthToken.objects.create(user=obj).key user = UserOut.from_orm(obj) return cls(token=token, user=user) @@ -111,7 +112,7 @@ class SignupIn(BaseModel): @sync_to_async -def __get_login_user(username: str) -> User: +def __get_login_user(username: str) -> UserType: kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()} try: user = User.objects.get(**kwargs) @@ -122,7 +123,7 @@ def __get_login_user(username: str) -> User: raise AuthenticationFailed(code="user_not_found", detail="User not found") -async def get_login_user(challenge: LoginChallengeIn) -> User: +async def get_login_user(challenge: LoginChallengeIn) -> UserType: user = await __get_login_user(challenge.username) return user @@ -138,7 +139,7 @@ def get_encryption_key(salt): ) -def save_changed_password(data: ChangePassword, user: User): +def save_changed_password(data: ChangePassword, user: UserType): response_data = data.response_data user_info: UserInfo = user.userinfo user_info.loginPubkey = response_data.loginPubkey @@ -150,7 +151,7 @@ def save_changed_password(data: ChangePassword, user: User): def validate_login_request( validated_data: LoginResponse, challenge_sent_to_user: Authentication, - user: User, + user: UserType, expected_action: str, host_from_request: str, ): @@ -159,7 +160,7 @@ def validate_login_request( challenge_data = msgpack_decode(box.decrypt(validated_data.challenge)) now = int(datetime.now().timestamp()) if validated_data.action != expected_action: - raise HttpError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else') + raise HttpError("wrong_action", f'Expected "{expected_action}" but got something else') elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS: raise HttpError("challenge_expired", "Login challenge has expired") elif challenge_data["userId"] != user.id: @@ -181,7 +182,7 @@ async def is_etebase(): @authentication_router.post("/login_challenge/", response_model=LoginChallengeOut) -def login_challenge(user: User = Depends(get_login_user)): +def login_challenge(user: UserType = Depends(get_login_user)): salt = bytes(user.userinfo.salt) enc_key = get_encryption_key(salt) box = nacl.secret.SecretBox(enc_key) @@ -210,14 +211,14 @@ def logout(auth_data: AuthData = Depends(get_auth_data)): @authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses) -async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)): +async def change_password(data: ChangePassword, request: Request, user: UserType = Depends(get_authenticated_user)): host = request.headers.get("Host") await validate_login_request(data.response_data, data, user, "changePassword", host) await sync_to_async(save_changed_password)(data, user) @authentication_router.post("/dashboard_url/", responses=permission_responses) -def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)): +def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_user)): get_dashboard_url = app_settings.DASHBOARD_URL_FUNC if get_dashboard_url is None: raise HttpError("not_supported", "This server doesn't have a user dashboard.") @@ -228,7 +229,7 @@ def dashboard_url(request: Request, user: User = Depends(get_authenticated_user) return ret -def signup_save(data: SignupIn, request: Request) -> User: +def signup_save(data: SignupIn, request: Request) -> UserType: user_data = data.user with transaction.atomic(): try: diff --git a/etebase_fastapi/collection.py b/etebase_fastapi/collection.py index 5c6e6b6..9e25b38 100644 --- a/etebase_fastapi/collection.py +++ b/etebase_fastapi/collection.py @@ -1,7 +1,6 @@ import typing as t from asgiref.sync import sync_to_async -from django.contrib.auth import get_user_model from django.core import exceptions as django_exceptions from django.core.files.base import ContentFile from django.db import transaction, IntegrityError @@ -9,6 +8,7 @@ from django.db.models import Q, QuerySet from fastapi import APIRouter, Depends, status, Request from django_etebase import models +from myauth.models import UserType, get_typed_user_model from .authentication import get_authenticated_user from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError from .msgpack import MsgpackRoute @@ -27,7 +27,7 @@ from .utils import ( from .dependencies import get_collection_queryset, get_item_queryset, get_collection from .sendfile import sendfile -User = get_user_model() +User = get_typed_user_model collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) @@ -36,11 +36,14 @@ class ListMulti(BaseModel): collectionTypes: t.List[bytes] +ChunkType = t.Tuple[str, t.Optional[bytes]] + + class CollectionItemRevisionInOut(BaseModel): uid: str meta: bytes deleted: bool - chunks: t.List[t.Tuple[str, t.Optional[bytes]]] + chunks: t.List[ChunkType] class Config: orm_mode = True @@ -49,7 +52,7 @@ class CollectionItemRevisionInOut(BaseModel): def from_orm_context( cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context ) -> "CollectionItemRevisionInOut": - chunks = [] + chunks: t.List[ChunkType] = [] for chunk_relation in obj.chunks_relation.all(): chunk_obj = chunk_relation.chunk if context.prefetch == "auto": @@ -185,7 +188,7 @@ class ItemBatchIn(BaseModel): @sync_to_async def collection_list_common( queryset: QuerySet, - user: User, + user: UserType, stoken: t.Optional[str], limit: int, prefetch: Prefetch, @@ -210,7 +213,7 @@ def collection_list_common( remed = remed_qs.values_list("collection__uid", flat=True) if len(remed) > 0: - ret.removedMemberships = [{"uid": x} for x in remed] + ret.removedMemberships = [RemovedMembershipOut(uid=x) for x in remed] return ret @@ -219,14 +222,14 @@ def collection_list_common( def verify_collection_admin( - collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user) + collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user) ): if not is_collection_admin(collection, user): raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.") def has_write_access( - collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user) + collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user) ): member = collection.members.get(user=user) if member.accessLevel == models.AccessLevels.READ_ONLY: @@ -247,7 +250,7 @@ async def list_multi( stoken: t.Optional[str] = None, limit: int = 50, queryset: QuerySet = Depends(get_collection_queryset), - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): # FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration") @@ -263,7 +266,7 @@ async def collection_list( stoken: t.Optional[str] = None, limit: int = 50, prefetch: Prefetch = PrefetchQuery, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_collection_queryset), ): return await collection_list_common(queryset, user, stoken, limit, prefetch) @@ -299,7 +302,7 @@ def process_revisions_for_item(item: models.CollectionItem, revision_data: Colle return revision -def _create(data: CollectionIn, user: User): +def _create(data: CollectionIn, user: UserType): with transaction.atomic(): if data.item.etag is not None: raise ValidationError("bad_etag", "etag is not null") @@ -335,14 +338,14 @@ def _create(data: CollectionIn, user: User): @collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE) -async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)): +async def create(data: CollectionIn, user: UserType = Depends(get_authenticated_user)): await sync_to_async(_create)(data, user) @collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ) def collection_get( obj: models.Collection = Depends(get_collection), - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): return CollectionOut.from_orm_context(obj, Context(user, prefetch)) @@ -393,7 +396,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val def item_get( item_uid: str, queryset: QuerySet = Depends(get_item_queryset), - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): obj = queryset.get(uid=item_uid) @@ -403,7 +406,7 @@ def item_get( @sync_to_async def item_list_common( queryset: QuerySet, - user: User, + user: UserType, stoken: t.Optional[str], limit: int, prefetch: Prefetch, @@ -424,7 +427,7 @@ async def item_list( limit: int = 50, prefetch: Prefetch = PrefetchQuery, withCollection: bool = False, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), ): if not withCollection: queryset = queryset.filter(parent__isnull=True) @@ -433,7 +436,7 @@ async def item_list( return response -def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid: str, validate_etag: bool): +def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool): queryset = get_collection_queryset(user) with transaction.atomic(): # We need this for locking the collection object collection_object = queryset.select_for_update().get(uid=uid) @@ -467,7 +470,7 @@ def item_revisions( limit: int = 50, iterator: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), items: QuerySet = Depends(get_item_queryset), ): item = get_object_or_404(items, uid=item_uid) @@ -501,7 +504,7 @@ def fetch_updates( data: t.List[CollectionItemBulkGetIn], stoken: t.Optional[str] = None, prefetch: Prefetch = PrefetchQuery, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_item_queryset), ): # FIXME: make configurable? @@ -531,14 +534,14 @@ def fetch_updates( @item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) def item_transaction( - collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user) + collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user) ): return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True) @item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) def item_batch( - collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user) + collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user) ): return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False) diff --git a/etebase_fastapi/dependencies.py b/etebase_fastapi/dependencies.py index ddb9b3b..fb9cec5 100644 --- a/etebase_fastapi/dependencies.py +++ b/etebase_fastapi/dependencies.py @@ -3,17 +3,17 @@ import dataclasses from fastapi import Depends from fastapi.security import APIKeyHeader -from django.contrib.auth import get_user_model from django.utils import timezone from django.db.models import QuerySet from django_etebase import models from django_etebase.token_auth.models import AuthToken, get_default_expiry +from myauth.models import UserType, get_typed_user_model from .exceptions import AuthenticationFailed from .utils import get_object_or_404 -User = get_user_model() +User = get_typed_user_model() token_scheme = APIKeyHeader(name="Authorization") AUTO_REFRESH = True MIN_REFRESH_INTERVAL = 60 @@ -21,7 +21,7 @@ MIN_REFRESH_INTERVAL = 60 @dataclasses.dataclass(frozen=True) class AuthData: - user: User + user: UserType token: AuthToken @@ -60,12 +60,12 @@ def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData: return AuthData(user, token) -def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User: +def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType: user, _ = __get_authenticated_user(api_token) return user -def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet: +def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet: default_queryset: QuerySet = models.Collection.objects.all() return default_queryset.filter(members__user=user) diff --git a/etebase_fastapi/invitation.py b/etebase_fastapi/invitation.py index 9e731bc..eb9f549 100644 --- a/etebase_fastapi/invitation.py +++ b/etebase_fastapi/invitation.py @@ -1,12 +1,12 @@ import typing as t -from django.contrib.auth import get_user_model from django.db import transaction, IntegrityError from django.db.models import QuerySet from fastapi import APIRouter, Depends, status, Request from django_etebase import models from django_etebase.utils import get_user_queryset, CallbackContext +from myauth.models import UserType, get_typed_user_model from .authentication import get_authenticated_user from .exceptions import HttpError, PermissionDenied from .msgpack import MsgpackRoute @@ -20,7 +20,7 @@ from .utils import ( PERMISSIONS_READWRITE, ) -User = get_user_model() +User = get_typed_user_model() invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) default_queryset: QuerySet = models.CollectionInvitation.objects.all() @@ -53,7 +53,8 @@ class CollectionInvitationCommon(BaseModel): class CollectionInvitationIn(CollectionInvitationCommon): def validate_db(self, context: Context): - if context.user.username == self.username.lower(): + user = context.user + if user is not None and (user.username == self.username.lower()): raise HttpError("no_self_invite", "Inviting yourself is not allowed") @@ -84,11 +85,11 @@ class InvitationListResponse(BaseModel): done: bool -def get_incoming_queryset(user: User = Depends(get_authenticated_user)): +def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)): return default_queryset.filter(user=user) -def get_outgoing_queryset(user: User = Depends(get_authenticated_user)): +def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)): return default_queryset.filter(fromMember__user=user) @@ -183,7 +184,7 @@ def incoming_accept( def outgoing_create( data: CollectionInvitationIn, request: Request, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), ): collection = get_object_or_404(models.Collection.objects, uid=data.collection) to_user = get_object_or_404( @@ -231,7 +232,7 @@ def outgoing_delete( def outgoing_fetch_user_profile( username: str, request: Request, - user: User = Depends(get_authenticated_user), + user: UserType = Depends(get_authenticated_user), ): kwargs = {User.USERNAME_FIELD: username.lower()} user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs) diff --git a/etebase_fastapi/member.py b/etebase_fastapi/member.py index 725d44b..22977ac 100644 --- a/etebase_fastapi/member.py +++ b/etebase_fastapi/member.py @@ -1,11 +1,11 @@ import typing as t -from django.contrib.auth import get_user_model from django.db import transaction from django.db.models import QuerySet from fastapi import APIRouter, Depends, status from django_etebase import models +from myauth.models import UserType, get_typed_user_model from .authentication import get_authenticated_user from .msgpack import MsgpackRoute from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE @@ -13,7 +13,7 @@ from .stoken_handler import filter_by_stoken_and_limit from .collection import get_collection, verify_collection_admin -User = get_user_model() +User = get_typed_user_model() member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) default_queryset: QuerySet = models.CollectionMember.objects.all() @@ -98,6 +98,8 @@ def member_patch( @member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ) -def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)): +def member_leave( + user: UserType = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection) +): obj = get_object_or_404(collection.members, user=user) obj.revoke() diff --git a/etebase_fastapi/msgpack.py b/etebase_fastapi/msgpack.py index edffd7e..915e783 100644 --- a/etebase_fastapi/msgpack.py +++ b/etebase_fastapi/msgpack.py @@ -19,13 +19,15 @@ class MsgpackRequest(Request): class MsgpackResponse(Response): media_type = "application/msgpack" - def render(self, content: t.Optional[t.Any]) -> t.Optional[bytes]: + def render(self, content: t.Optional[t.Any]) -> bytes: if content is None: return b"" if isinstance(content, BaseModel): content = content.dict() - return msgpack.packb(content, use_bin_type=True) + ret = msgpack.packb(content, use_bin_type=True) + assert ret is not None + return ret class MsgpackRoute(APIRoute): diff --git a/etebase_fastapi/test_reset_view.py b/etebase_fastapi/test_reset_view.py index 3075290..e328875 100644 --- a/etebase_fastapi/test_reset_view.py +++ b/etebase_fastapi/test_reset_view.py @@ -1,5 +1,4 @@ from django.conf import settings -from django.contrib.auth import get_user_model from django.db import transaction from django.shortcuts import get_object_or_404 from fastapi import APIRouter, Request, status @@ -8,9 +7,10 @@ from django_etebase.utils import get_user_queryset, CallbackContext from etebase_fastapi.authentication import SignupIn, signup_save from etebase_fastapi.msgpack import MsgpackRoute from etebase_fastapi.exceptions import HttpError +from myauth.models import get_typed_user_model test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"]) -User = get_user_model() +User = get_typed_user_model() @test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT) diff --git a/etebase_fastapi/utils.py b/etebase_fastapi/utils.py index 7280018..c91c3ec 100644 --- a/etebase_fastapi/utils.py +++ b/etebase_fastapi/utils.py @@ -8,14 +8,14 @@ from pydantic import BaseModel as PyBaseModel from django.db.models import QuerySet from django.core.exceptions import ObjectDoesNotExist -from django.contrib.auth import get_user_model from django_etebase import app_settings from django_etebase.models import AccessLevels +from myauth.models import UserType, get_typed_user_model from .exceptions import HttpError, HttpErrorOut -User = get_user_model() +User = get_typed_user_model() Prefetch = t.Literal["auto", "medium"] PrefetchQuery = Query(default="auto") @@ -30,7 +30,7 @@ class BaseModel(PyBaseModel): @dataclasses.dataclass class Context: - user: t.Optional[User] + user: t.Optional[UserType] prefetch: t.Optional[Prefetch] diff --git a/myauth/forms.py b/myauth/forms.py index 7aacb9b..fc2be74 100644 --- a/myauth/forms.py +++ b/myauth/forms.py @@ -1,8 +1,8 @@ from django import forms -from django.contrib.auth import get_user_model from django.contrib.auth.forms import UsernameField +from myauth.models import get_typed_user_model -User = get_user_model() +User = get_typed_user_model() class AdminUserCreationForm(forms.ModelForm): diff --git a/myauth/models.py b/myauth/models.py index d6585a8..5bc4af7 100644 --- a/myauth/models.py +++ b/myauth/models.py @@ -1,3 +1,5 @@ +import typing as t + from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager from django.core import validators from django.db import models @@ -28,9 +30,21 @@ class User(AbstractUser): unique=True, help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."), validators=[username_validator], - error_messages={"unique": _("A user with that username already exists."),}, + error_messages={ + "unique": _("A user with that username already exists."), + }, ) @classmethod def normalize_username(cls, username): return super().normalize_username(username).lower() + + +UserType = t.Type[User] + + +def get_typed_user_model() -> UserType: + from django.contrib.auth import get_user_model + + ret: t.Any = get_user_model() + return ret