Fix many type errors.

pull/72/head
Tom Hacohen 3 years ago
parent e13f26ec56
commit 794b5f3983

@ -1,9 +1,9 @@
from django.contrib.auth import get_user_model
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
from myauth.models import get_typed_user_model
User = get_user_model() User = get_typed_user_model()
def generate_key(): def generate_key():

@ -1,13 +1,13 @@
import typing as t import typing as t
from dataclasses import dataclass from dataclasses import dataclass
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from myauth.models import UserType, get_typed_user_model
from . import app_settings from . import app_settings
User = get_user_model() User = get_typed_user_model()
@dataclass @dataclass
@ -15,7 +15,7 @@ class CallbackContext:
"""Class for passing extra context to callbacks""" """Class for passing extra context to callbacks"""
url_kwargs: t.Dict[str, t.Any] url_kwargs: t.Dict[str, t.Any]
user: t.Optional[User] = None user: t.Optional[UserType] = None
def get_user_queryset(queryset, context: CallbackContext): def get_user_queryset(queryset, context: CallbackContext):

@ -9,7 +9,7 @@ import nacl.secret
import nacl.signing import nacl.signing
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in from django.contrib.auth import user_logged_out, user_logged_in
from django.core import exceptions as django_exceptions from django.core import exceptions as django_exceptions
from django.db import transaction from django.db import transaction
from fastapi import APIRouter, Depends, status, Request from fastapi import APIRouter, Depends, status, Request
@ -19,12 +19,13 @@ from django_etebase.token_auth.models import AuthToken
from django_etebase.models import UserInfo from django_etebase.models import UserInfo
from django_etebase.signals import user_signed_up from django_etebase.signals import user_signed_up
from django_etebase.utils import create_user, get_user_queryset, CallbackContext from django_etebase.utils import create_user, get_user_queryset, CallbackContext
from myauth.models import UserType, get_typed_user_model
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
from .msgpack import MsgpackRoute from .msgpack import MsgpackRoute
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
from .dependencies import AuthData, get_auth_data, get_authenticated_user from .dependencies import AuthData, get_auth_data, get_authenticated_user
User = get_user_model() User = get_typed_user_model()
authentication_router = APIRouter(route_class=MsgpackRoute) authentication_router = APIRouter(route_class=MsgpackRoute)
@ -52,7 +53,7 @@ class UserOut(BaseModel):
encryptedContent: bytes encryptedContent: bytes
@classmethod @classmethod
def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut": def from_orm(cls: t.Type["UserOut"], obj: UserType) -> "UserOut":
return cls( return cls(
username=obj.username, username=obj.username,
email=obj.email, email=obj.email,
@ -66,7 +67,7 @@ class LoginOut(BaseModel):
user: UserOut user: UserOut
@classmethod @classmethod
def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut": def from_orm(cls: t.Type["LoginOut"], obj: UserType) -> "LoginOut":
token = AuthToken.objects.create(user=obj).key token = AuthToken.objects.create(user=obj).key
user = UserOut.from_orm(obj) user = UserOut.from_orm(obj)
return cls(token=token, user=user) return cls(token=token, user=user)
@ -111,7 +112,7 @@ class SignupIn(BaseModel):
@sync_to_async @sync_to_async
def __get_login_user(username: str) -> User: def __get_login_user(username: str) -> UserType:
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()} kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
try: try:
user = User.objects.get(**kwargs) user = User.objects.get(**kwargs)
@ -122,7 +123,7 @@ def __get_login_user(username: str) -> User:
raise AuthenticationFailed(code="user_not_found", detail="User not found") raise AuthenticationFailed(code="user_not_found", detail="User not found")
async def get_login_user(challenge: LoginChallengeIn) -> User: async def get_login_user(challenge: LoginChallengeIn) -> UserType:
user = await __get_login_user(challenge.username) user = await __get_login_user(challenge.username)
return user return user
@ -138,7 +139,7 @@ def get_encryption_key(salt):
) )
def save_changed_password(data: ChangePassword, user: User): def save_changed_password(data: ChangePassword, user: UserType):
response_data = data.response_data response_data = data.response_data
user_info: UserInfo = user.userinfo user_info: UserInfo = user.userinfo
user_info.loginPubkey = response_data.loginPubkey user_info.loginPubkey = response_data.loginPubkey
@ -150,7 +151,7 @@ def save_changed_password(data: ChangePassword, user: User):
def validate_login_request( def validate_login_request(
validated_data: LoginResponse, validated_data: LoginResponse,
challenge_sent_to_user: Authentication, challenge_sent_to_user: Authentication,
user: User, user: UserType,
expected_action: str, expected_action: str,
host_from_request: str, host_from_request: str,
): ):
@ -159,7 +160,7 @@ def validate_login_request(
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge)) challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
now = int(datetime.now().timestamp()) now = int(datetime.now().timestamp())
if validated_data.action != expected_action: if validated_data.action != expected_action:
raise HttpError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else') raise HttpError("wrong_action", f'Expected "{expected_action}" but got something else')
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS: elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
raise HttpError("challenge_expired", "Login challenge has expired") raise HttpError("challenge_expired", "Login challenge has expired")
elif challenge_data["userId"] != user.id: elif challenge_data["userId"] != user.id:
@ -181,7 +182,7 @@ async def is_etebase():
@authentication_router.post("/login_challenge/", response_model=LoginChallengeOut) @authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
def login_challenge(user: User = Depends(get_login_user)): def login_challenge(user: UserType = Depends(get_login_user)):
salt = bytes(user.userinfo.salt) salt = bytes(user.userinfo.salt)
enc_key = get_encryption_key(salt) enc_key = get_encryption_key(salt)
box = nacl.secret.SecretBox(enc_key) box = nacl.secret.SecretBox(enc_key)
@ -210,14 +211,14 @@ def logout(auth_data: AuthData = Depends(get_auth_data)):
@authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses) @authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)): async def change_password(data: ChangePassword, request: Request, user: UserType = Depends(get_authenticated_user)):
host = request.headers.get("Host") host = request.headers.get("Host")
await validate_login_request(data.response_data, data, user, "changePassword", host) await validate_login_request(data.response_data, data, user, "changePassword", host)
await sync_to_async(save_changed_password)(data, user) await sync_to_async(save_changed_password)(data, user)
@authentication_router.post("/dashboard_url/", responses=permission_responses) @authentication_router.post("/dashboard_url/", responses=permission_responses)
def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)): def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_user)):
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
if get_dashboard_url is None: if get_dashboard_url is None:
raise HttpError("not_supported", "This server doesn't have a user dashboard.") raise HttpError("not_supported", "This server doesn't have a user dashboard.")
@ -228,7 +229,7 @@ def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)
return ret return ret
def signup_save(data: SignupIn, request: Request) -> User: def signup_save(data: SignupIn, request: Request) -> UserType:
user_data = data.user user_data = data.user
with transaction.atomic(): with transaction.atomic():
try: try:

@ -1,7 +1,6 @@
import typing as t import typing as t
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.auth import get_user_model
from django.core import exceptions as django_exceptions from django.core import exceptions as django_exceptions
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.db import transaction, IntegrityError from django.db import transaction, IntegrityError
@ -9,6 +8,7 @@ from django.db.models import Q, QuerySet
from fastapi import APIRouter, Depends, status, Request from fastapi import APIRouter, Depends, status, Request
from django_etebase import models from django_etebase import models
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
from .msgpack import MsgpackRoute from .msgpack import MsgpackRoute
@ -27,7 +27,7 @@ from .utils import (
from .dependencies import get_collection_queryset, get_item_queryset, get_collection from .dependencies import get_collection_queryset, get_item_queryset, get_collection
from .sendfile import sendfile from .sendfile import sendfile
User = get_user_model() User = get_typed_user_model
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
@ -36,11 +36,14 @@ class ListMulti(BaseModel):
collectionTypes: t.List[bytes] collectionTypes: t.List[bytes]
ChunkType = t.Tuple[str, t.Optional[bytes]]
class CollectionItemRevisionInOut(BaseModel): class CollectionItemRevisionInOut(BaseModel):
uid: str uid: str
meta: bytes meta: bytes
deleted: bool deleted: bool
chunks: t.List[t.Tuple[str, t.Optional[bytes]]] chunks: t.List[ChunkType]
class Config: class Config:
orm_mode = True orm_mode = True
@ -49,7 +52,7 @@ class CollectionItemRevisionInOut(BaseModel):
def from_orm_context( def from_orm_context(
cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context
) -> "CollectionItemRevisionInOut": ) -> "CollectionItemRevisionInOut":
chunks = [] chunks: t.List[ChunkType] = []
for chunk_relation in obj.chunks_relation.all(): for chunk_relation in obj.chunks_relation.all():
chunk_obj = chunk_relation.chunk chunk_obj = chunk_relation.chunk
if context.prefetch == "auto": if context.prefetch == "auto":
@ -185,7 +188,7 @@ class ItemBatchIn(BaseModel):
@sync_to_async @sync_to_async
def collection_list_common( def collection_list_common(
queryset: QuerySet, queryset: QuerySet,
user: User, user: UserType,
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
prefetch: Prefetch, prefetch: Prefetch,
@ -210,7 +213,7 @@ def collection_list_common(
remed = remed_qs.values_list("collection__uid", flat=True) remed = remed_qs.values_list("collection__uid", flat=True)
if len(remed) > 0: if len(remed) > 0:
ret.removedMemberships = [{"uid": x} for x in remed] ret.removedMemberships = [RemovedMembershipOut(uid=x) for x in remed]
return ret return ret
@ -219,14 +222,14 @@ def collection_list_common(
def verify_collection_admin( def verify_collection_admin(
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user) collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
): ):
if not is_collection_admin(collection, user): if not is_collection_admin(collection, user):
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.") raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
def has_write_access( def has_write_access(
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user) collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
): ):
member = collection.members.get(user=user) member = collection.members.get(user=user)
if member.accessLevel == models.AccessLevels.READ_ONLY: if member.accessLevel == models.AccessLevels.READ_ONLY:
@ -247,7 +250,7 @@ async def list_multi(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_collection_queryset), queryset: QuerySet = Depends(get_collection_queryset),
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration") # FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
@ -263,7 +266,7 @@ async def collection_list(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_collection_queryset), queryset: QuerySet = Depends(get_collection_queryset),
): ):
return await collection_list_common(queryset, user, stoken, limit, prefetch) return await collection_list_common(queryset, user, stoken, limit, prefetch)
@ -299,7 +302,7 @@ def process_revisions_for_item(item: models.CollectionItem, revision_data: Colle
return revision return revision
def _create(data: CollectionIn, user: User): def _create(data: CollectionIn, user: UserType):
with transaction.atomic(): with transaction.atomic():
if data.item.etag is not None: if data.item.etag is not None:
raise ValidationError("bad_etag", "etag is not null") raise ValidationError("bad_etag", "etag is not null")
@ -335,14 +338,14 @@ def _create(data: CollectionIn, user: User):
@collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE) @collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)): async def create(data: CollectionIn, user: UserType = Depends(get_authenticated_user)):
await sync_to_async(_create)(data, user) await sync_to_async(_create)(data, user)
@collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ) @collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ)
def collection_get( def collection_get(
obj: models.Collection = Depends(get_collection), obj: models.Collection = Depends(get_collection),
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
return CollectionOut.from_orm_context(obj, Context(user, prefetch)) return CollectionOut.from_orm_context(obj, Context(user, prefetch))
@ -393,7 +396,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
def item_get( def item_get(
item_uid: str, item_uid: str,
queryset: QuerySet = Depends(get_item_queryset), queryset: QuerySet = Depends(get_item_queryset),
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
obj = queryset.get(uid=item_uid) obj = queryset.get(uid=item_uid)
@ -403,7 +406,7 @@ def item_get(
@sync_to_async @sync_to_async
def item_list_common( def item_list_common(
queryset: QuerySet, queryset: QuerySet,
user: User, user: UserType,
stoken: t.Optional[str], stoken: t.Optional[str],
limit: int, limit: int,
prefetch: Prefetch, prefetch: Prefetch,
@ -424,7 +427,7 @@ async def item_list(
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
withCollection: bool = False, withCollection: bool = False,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
if not withCollection: if not withCollection:
queryset = queryset.filter(parent__isnull=True) queryset = queryset.filter(parent__isnull=True)
@ -433,7 +436,7 @@ async def item_list(
return response return response
def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid: str, validate_etag: bool): def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
queryset = get_collection_queryset(user) queryset = get_collection_queryset(user)
with transaction.atomic(): # We need this for locking the collection object with transaction.atomic(): # We need this for locking the collection object
collection_object = queryset.select_for_update().get(uid=uid) collection_object = queryset.select_for_update().get(uid=uid)
@ -467,7 +470,7 @@ def item_revisions(
limit: int = 50, limit: int = 50,
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
items: QuerySet = Depends(get_item_queryset), items: QuerySet = Depends(get_item_queryset),
): ):
item = get_object_or_404(items, uid=item_uid) item = get_object_or_404(items, uid=item_uid)
@ -501,7 +504,7 @@ def fetch_updates(
data: t.List[CollectionItemBulkGetIn], data: t.List[CollectionItemBulkGetIn],
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_item_queryset), queryset: QuerySet = Depends(get_item_queryset),
): ):
# FIXME: make configurable? # FIXME: make configurable?
@ -531,14 +534,14 @@ def fetch_updates(
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) @item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
def item_transaction( def item_transaction(
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user) collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
): ):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True) return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True)
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) @item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
def item_batch( def item_batch(
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user) collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
): ):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False) return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False)

@ -3,17 +3,17 @@ import dataclasses
from fastapi import Depends from fastapi import Depends
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from django.contrib.auth import get_user_model
from django.utils import timezone from django.utils import timezone
from django.db.models import QuerySet from django.db.models import QuerySet
from django_etebase import models from django_etebase import models
from django_etebase.token_auth.models import AuthToken, get_default_expiry from django_etebase.token_auth.models import AuthToken, get_default_expiry
from myauth.models import UserType, get_typed_user_model
from .exceptions import AuthenticationFailed from .exceptions import AuthenticationFailed
from .utils import get_object_or_404 from .utils import get_object_or_404
User = get_user_model() User = get_typed_user_model()
token_scheme = APIKeyHeader(name="Authorization") token_scheme = APIKeyHeader(name="Authorization")
AUTO_REFRESH = True AUTO_REFRESH = True
MIN_REFRESH_INTERVAL = 60 MIN_REFRESH_INTERVAL = 60
@ -21,7 +21,7 @@ MIN_REFRESH_INTERVAL = 60
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class AuthData: class AuthData:
user: User user: UserType
token: AuthToken token: AuthToken
@ -60,12 +60,12 @@ def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
return AuthData(user, token) return AuthData(user, token)
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User: def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType:
user, _ = __get_authenticated_user(api_token) user, _ = __get_authenticated_user(api_token)
return user return user
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet: def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet:
default_queryset: QuerySet = models.Collection.objects.all() default_queryset: QuerySet = models.Collection.objects.all()
return default_queryset.filter(members__user=user) return default_queryset.filter(members__user=user)

@ -1,12 +1,12 @@
import typing as t import typing as t
from django.contrib.auth import get_user_model
from django.db import transaction, IntegrityError from django.db import transaction, IntegrityError
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status, Request from fastapi import APIRouter, Depends, status, Request
from django_etebase import models from django_etebase import models
from django_etebase.utils import get_user_queryset, CallbackContext from django_etebase.utils import get_user_queryset, CallbackContext
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .exceptions import HttpError, PermissionDenied from .exceptions import HttpError, PermissionDenied
from .msgpack import MsgpackRoute from .msgpack import MsgpackRoute
@ -20,7 +20,7 @@ from .utils import (
PERMISSIONS_READWRITE, PERMISSIONS_READWRITE,
) )
User = get_user_model() User = get_typed_user_model()
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionInvitation.objects.all() default_queryset: QuerySet = models.CollectionInvitation.objects.all()
@ -53,7 +53,8 @@ class CollectionInvitationCommon(BaseModel):
class CollectionInvitationIn(CollectionInvitationCommon): class CollectionInvitationIn(CollectionInvitationCommon):
def validate_db(self, context: Context): def validate_db(self, context: Context):
if context.user.username == self.username.lower(): user = context.user
if user is not None and (user.username == self.username.lower()):
raise HttpError("no_self_invite", "Inviting yourself is not allowed") raise HttpError("no_self_invite", "Inviting yourself is not allowed")
@ -84,11 +85,11 @@ class InvitationListResponse(BaseModel):
done: bool done: bool
def get_incoming_queryset(user: User = Depends(get_authenticated_user)): def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(user=user) return default_queryset.filter(user=user)
def get_outgoing_queryset(user: User = Depends(get_authenticated_user)): def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(fromMember__user=user) return default_queryset.filter(fromMember__user=user)
@ -183,7 +184,7 @@ def incoming_accept(
def outgoing_create( def outgoing_create(
data: CollectionInvitationIn, data: CollectionInvitationIn,
request: Request, request: Request,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
collection = get_object_or_404(models.Collection.objects, uid=data.collection) collection = get_object_or_404(models.Collection.objects, uid=data.collection)
to_user = get_object_or_404( to_user = get_object_or_404(
@ -231,7 +232,7 @@ def outgoing_delete(
def outgoing_fetch_user_profile( def outgoing_fetch_user_profile(
username: str, username: str,
request: Request, request: Request,
user: User = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
kwargs = {User.USERNAME_FIELD: username.lower()} kwargs = {User.USERNAME_FIELD: username.lower()}
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs) user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)

@ -1,11 +1,11 @@
import typing as t import typing as t
from django.contrib.auth import get_user_model
from django.db import transaction from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from django_etebase import models from django_etebase import models
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
from .msgpack import MsgpackRoute from .msgpack import MsgpackRoute
from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
@ -13,7 +13,7 @@ from .stoken_handler import filter_by_stoken_and_limit
from .collection import get_collection, verify_collection_admin from .collection import get_collection, verify_collection_admin
User = get_user_model() User = get_typed_user_model()
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionMember.objects.all() default_queryset: QuerySet = models.CollectionMember.objects.all()
@ -98,6 +98,8 @@ def member_patch(
@member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ) @member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ)
def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)): def member_leave(
user: UserType = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)
):
obj = get_object_or_404(collection.members, user=user) obj = get_object_or_404(collection.members, user=user)
obj.revoke() obj.revoke()

@ -19,13 +19,15 @@ class MsgpackRequest(Request):
class MsgpackResponse(Response): class MsgpackResponse(Response):
media_type = "application/msgpack" media_type = "application/msgpack"
def render(self, content: t.Optional[t.Any]) -> t.Optional[bytes]: def render(self, content: t.Optional[t.Any]) -> bytes:
if content is None: if content is None:
return b"" return b""
if isinstance(content, BaseModel): if isinstance(content, BaseModel):
content = content.dict() content = content.dict()
return msgpack.packb(content, use_bin_type=True) ret = msgpack.packb(content, use_bin_type=True)
assert ret is not None
return ret
class MsgpackRoute(APIRoute): class MsgpackRoute(APIRoute):

@ -1,5 +1,4 @@
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import transaction from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from fastapi import APIRouter, Request, status from fastapi import APIRouter, Request, status
@ -8,9 +7,10 @@ from django_etebase.utils import get_user_queryset, CallbackContext
from etebase_fastapi.authentication import SignupIn, signup_save from etebase_fastapi.authentication import SignupIn, signup_save
from etebase_fastapi.msgpack import MsgpackRoute from etebase_fastapi.msgpack import MsgpackRoute
from etebase_fastapi.exceptions import HttpError from etebase_fastapi.exceptions import HttpError
from myauth.models import get_typed_user_model
test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"]) test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"])
User = get_user_model() User = get_typed_user_model()
@test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT) @test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT)

@ -8,14 +8,14 @@ from pydantic import BaseModel as PyBaseModel
from django.db.models import QuerySet from django.db.models import QuerySet
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.contrib.auth import get_user_model
from django_etebase import app_settings from django_etebase import app_settings
from django_etebase.models import AccessLevels from django_etebase.models import AccessLevels
from myauth.models import UserType, get_typed_user_model
from .exceptions import HttpError, HttpErrorOut from .exceptions import HttpError, HttpErrorOut
User = get_user_model() User = get_typed_user_model()
Prefetch = t.Literal["auto", "medium"] Prefetch = t.Literal["auto", "medium"]
PrefetchQuery = Query(default="auto") PrefetchQuery = Query(default="auto")
@ -30,7 +30,7 @@ class BaseModel(PyBaseModel):
@dataclasses.dataclass @dataclasses.dataclass
class Context: class Context:
user: t.Optional[User] user: t.Optional[UserType]
prefetch: t.Optional[Prefetch] prefetch: t.Optional[Prefetch]

@ -1,8 +1,8 @@
from django import forms from django import forms
from django.contrib.auth import get_user_model
from django.contrib.auth.forms import UsernameField from django.contrib.auth.forms import UsernameField
from myauth.models import get_typed_user_model
User = get_user_model() User = get_typed_user_model()
class AdminUserCreationForm(forms.ModelForm): class AdminUserCreationForm(forms.ModelForm):

@ -1,3 +1,5 @@
import typing as t
from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager
from django.core import validators from django.core import validators
from django.db import models from django.db import models
@ -28,9 +30,21 @@ class User(AbstractUser):
unique=True, unique=True,
help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."), help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."),
validators=[username_validator], validators=[username_validator],
error_messages={"unique": _("A user with that username already exists."),}, error_messages={
"unique": _("A user with that username already exists."),
},
) )
@classmethod @classmethod
def normalize_username(cls, username): def normalize_username(cls, username):
return super().normalize_username(username).lower() return super().normalize_username(username).lower()
UserType = t.Type[User]
def get_typed_user_model() -> UserType:
from django.contrib.auth import get_user_model
ret: t.Any = get_user_model()
return ret

Loading…
Cancel
Save