mirror of
https://github.com/etesync/server
synced 2025-01-01 04:00:55 +00:00
change response content to pydantic models and error handling
This commit is contained in:
parent
a0d1d23d2d
commit
31e0e0b832
@ -2,7 +2,6 @@ import dataclasses
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from django.core import exceptions as django_exceptions
|
||||
|
||||
import nacl
|
||||
import nacl.encoding
|
||||
@ -12,6 +11,7 @@ import nacl.signing
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
||||
from django.core import exceptions as django_exceptions
|
||||
from django.db import transaction
|
||||
from django.utils import timezone
|
||||
from fastapi import APIRouter, Depends, status, Request, Response
|
||||
@ -21,7 +21,6 @@ from pydantic import BaseModel
|
||||
from django_etebase import app_settings, models
|
||||
from django_etebase.exceptions import EtebaseValidationError
|
||||
from django_etebase.models import UserInfo
|
||||
from django_etebase.serializers import UserSerializer
|
||||
from django_etebase.signals import user_signed_up
|
||||
from django_etebase.token_auth.models import AuthToken
|
||||
from django_etebase.token_auth.models import get_default_expiry
|
||||
@ -43,10 +42,16 @@ class AuthData:
|
||||
token: AuthToken
|
||||
|
||||
|
||||
class LoginChallengeData(BaseModel):
|
||||
class LoginChallengeIn(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class LoginChallengeOut(BaseModel):
|
||||
salt: bytes
|
||||
challenge: bytes
|
||||
version: int
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
username: str
|
||||
challenge: bytes
|
||||
@ -54,6 +59,26 @@ class LoginResponse(BaseModel):
|
||||
action: t.Literal["login", "changePassword"]
|
||||
|
||||
|
||||
class UserOut(BaseModel):
|
||||
pubkey: bytes
|
||||
encryptedContent: bytes
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut":
|
||||
return cls(pubkey=obj.userinfo.pubkey, encryptedContent=obj.userinfo.encryptedContent)
|
||||
|
||||
|
||||
class LoginOut(BaseModel):
|
||||
token: str
|
||||
user: UserOut
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut":
|
||||
token = AuthToken.objects.create(user=obj).key
|
||||
user = UserOut.from_orm(obj)
|
||||
return cls(token=token, user=user)
|
||||
|
||||
|
||||
class Authentication(BaseModel):
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
@ -145,7 +170,7 @@ def __get_login_user(username: str) -> User:
|
||||
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
||||
|
||||
|
||||
async def get_login_user(challenge: LoginChallengeData) -> User:
|
||||
async def get_login_user(challenge: LoginChallengeIn) -> User:
|
||||
user = await __get_login_user(challenge.username)
|
||||
return user
|
||||
|
||||
@ -161,7 +186,6 @@ def get_encryption_key(salt):
|
||||
)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def save_changed_password(data: ChangePassword, user: User):
|
||||
response_data = data.response_data
|
||||
user_info: UserInfo = user.userinfo
|
||||
@ -170,24 +194,6 @@ def save_changed_password(data: ChangePassword, user: User):
|
||||
user_info.save()
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def login_response_data(user: User):
|
||||
return {
|
||||
"token": AuthToken.objects.create(user=user).key,
|
||||
"user": UserSerializer(user).data,
|
||||
}
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def send_user_logged_in_async(user: User, request: Request):
|
||||
user_logged_in.send(sender=user.__class__, request=request, user=user)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def send_user_logged_out_async(user: User, request: Request):
|
||||
user_logged_out.send(sender=user.__class__, request=request, user=user)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def validate_login_request(
|
||||
validated_data: LoginResponse,
|
||||
@ -195,39 +201,26 @@ def validate_login_request(
|
||||
user: User,
|
||||
expected_action: str,
|
||||
host_from_request: str,
|
||||
) -> t.Optional[MsgpackResponse]:
|
||||
|
||||
):
|
||||
enc_key = get_encryption_key(bytes(user.userinfo.salt))
|
||||
box = nacl.secret.SecretBox(enc_key)
|
||||
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
||||
now = int(datetime.now().timestamp())
|
||||
if validated_data.action != expected_action:
|
||||
content = {
|
||||
"code": "wrong_action",
|
||||
"detail": 'Expected "{}" but got something else'.format(challenge_sent_to_user.response),
|
||||
}
|
||||
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise ValidationError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else')
|
||||
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
||||
content = {"code": "challenge_expired", "detail": "Login challenge has expired"}
|
||||
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise ValidationError("challenge_expired", "Login challenge has expired")
|
||||
elif challenge_data["userId"] != user.id:
|
||||
content = {"code": "wrong_user", "detail": "This challenge is for the wrong user"}
|
||||
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise ValidationError("wrong_user", "This challenge is for the wrong user")
|
||||
elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
|
||||
detail = 'Found wrong host name. Got: "{}" expected: "{}"'.format(validated_data.host, host_from_request)
|
||||
content = {"code": "wrong_host", "detail": detail}
|
||||
return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise ValidationError(
|
||||
"wrong_host", f'Found wrong host name. Got: "{validated_data.host}" expected: "{host_from_request}"'
|
||||
)
|
||||
verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
|
||||
|
||||
try:
|
||||
verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
|
||||
except nacl.exceptions.BadSignatureError:
|
||||
return MsgpackResponse(
|
||||
{"code": "login_bad_signature", "detail": "Wrong password for user."},
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
return None
|
||||
raise ValidationError("login_bad_signature", "Wrong password for user.", status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@authentication_router.post("/login_challenge/")
|
||||
@ -239,35 +232,34 @@ async def login_challenge(user: User = Depends(get_login_user)):
|
||||
"userId": user.id,
|
||||
}
|
||||
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
|
||||
return MsgpackResponse({"salt": user.userinfo.salt, "version": user.userinfo.version, "challenge": challenge})
|
||||
return MsgpackResponse(
|
||||
LoginChallengeOut(salt=user.userinfo.salt, challenge=challenge, version=user.userinfo.version)
|
||||
)
|
||||
|
||||
|
||||
@authentication_router.post("/login/")
|
||||
async def login(data: Login, request: Request):
|
||||
user = await get_login_user(LoginChallengeData(username=data.response_data.username))
|
||||
user = await get_login_user(LoginChallengeIn(username=data.response_data.username))
|
||||
host = request.headers.get("Host")
|
||||
bad_login_response = await validate_login_request(data.response_data, data, user, "login", host)
|
||||
if bad_login_response is not None:
|
||||
return bad_login_response
|
||||
data = await login_response_data(user)
|
||||
await send_user_logged_in_async(user, request)
|
||||
return MsgpackResponse(data, status_code=status.HTTP_200_OK)
|
||||
await validate_login_request(data.response_data, data, user, "login", host)
|
||||
data = await sync_to_async(LoginOut.from_orm)(user)
|
||||
await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user)
|
||||
return MsgpackResponse(content=data, status_code=status.HTTP_200_OK)
|
||||
|
||||
|
||||
@authentication_router.post("/logout/")
|
||||
async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
|
||||
await sync_to_async(auth_data.token.delete)()
|
||||
await send_user_logged_out_async(auth_data.user, request)
|
||||
# XXX-TOM
|
||||
await sync_to_async(user_logged_out.send)(sender=auth_data.user.__class__, request=None, user=auth_data.user)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@authentication_router.post("/change_password/")
|
||||
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
|
||||
host = request.headers.get("Host")
|
||||
bad_login_response = await validate_login_request(data.response_data, data, user, "changePassword", host)
|
||||
if bad_login_response is not None:
|
||||
return bad_login_response
|
||||
await save_changed_password(data, user)
|
||||
await validate_login_request(data.response_data, data, user, "changePassword", host)
|
||||
await sync_to_async(save_changed_password)(data, user)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@ -300,15 +292,10 @@ def signup_save(data: SignupIn) -> User:
|
||||
return instance
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def send_user_signed_up_async(user: User, request):
|
||||
user_signed_up.send(sender=user.__class__, request=request, user=user)
|
||||
|
||||
|
||||
@authentication_router.post("/signup/")
|
||||
async def signup(data: SignupIn):
|
||||
user = await sync_to_async(signup_save)(data)
|
||||
# XXX-TOM
|
||||
data = await login_response_data(user)
|
||||
await send_user_signed_up_async(user, None)
|
||||
data = await sync_to_async(LoginOut.from_orm)(user)
|
||||
await sync_to_async(user_signed_up.send)(sender=user.__class__, request=None, user=user)
|
||||
return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED)
|
||||
|
Loading…
Reference in New Issue
Block a user