diff --git a/etebase_fastapi/authentication.py b/etebase_fastapi/authentication.py index 697f3f4..3ae6c4b 100644 --- a/etebase_fastapi/authentication.py +++ b/etebase_fastapi/authentication.py @@ -2,6 +2,7 @@ import dataclasses import typing as t from datetime import datetime from functools import cached_property +from django.core import exceptions as django_exceptions import nacl import nacl.encoding @@ -11,16 +12,19 @@ import nacl.signing from asgiref.sync import sync_to_async from django.conf import settings from django.contrib.auth import get_user_model, user_logged_out, user_logged_in +from django.db import transaction from django.utils import timezone from fastapi import APIRouter, Depends, status, Request, Response from fastapi.security import APIKeyHeader from pydantic import BaseModel from django_etebase import app_settings +from django_etebase.exceptions import EtebaseValidationError from django_etebase.models import UserInfo from django_etebase.serializers import UserSerializer from django_etebase.token_auth.models import AuthToken from django_etebase.token_auth.models import get_default_expiry +from django_etebase.utils import create_user from django_etebase.views import msgpack_encode, msgpack_decode from .execptions import AuthenticationFailed from .msgpack import MsgpackResponse, MsgpackRoute @@ -74,6 +78,19 @@ class ChangePassword(Authentication): return ChangePasswordResponse(**msgpack_decode(self.response)) +class UserSignup(BaseModel): + username: str + email: str + + +class SignupIn(BaseModel): + user: UserSignup + salt: bytes + loginPubkey: bytes + pubkey: bytes + encryptedContent: bytes + + def __renew_token(auth_token: AuthToken): current_expiry = auth_token.expiry new_expiry = get_default_expiry() @@ -252,3 +269,38 @@ async def change_password(data: ChangePassword, request: Request, user: User = D return bad_login_response await save_changed_password(data, user) return Response(status_code=status.HTTP_204_NO_CONTENT) + + +@sync_to_async +def signup_save(data: SignupIn): + user_data = data.user + with transaction.atomic(): + try: + # XXX-TOM + # view = self.context.get("view", None) + # user_queryset = get_user_queryset(User.objects.all(), view) + user_queryset = User.objects.all() + instance = user_queryset.get(**{User.USERNAME_FIELD: user_data.username.lower()}) + except User.DoesNotExist: + # Create the user and save the casing the user chose as the first name + try: + # XXX-TOM + instance = create_user(**user_data.dict(), password=None, first_name=user_data.username, view=None) + instance.full_clean() + except EtebaseValidationError as e: + raise e + except django_exceptions.ValidationError as e: + self.transform_validation_error("user", e) + except Exception as e: + raise EtebaseValidationError("generic", str(e)) + + if hasattr(instance, "userinfo"): + raise EtebaseValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT) + + models.UserInfo.objects.create(**validated_data, owner=instance) + return instance + + +@authentication_router.post("/signup/") +async def signup(data: SignupIn): + pass diff --git a/etebase_fastapi/execptions.py b/etebase_fastapi/execptions.py index 8808f5d..2b35634 100644 --- a/etebase_fastapi/execptions.py +++ b/etebase_fastapi/execptions.py @@ -1,5 +1,7 @@ from fastapi import status +from django_etebase.exceptions import EtebaseValidationError + class CustomHttpException(Exception): def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): @@ -40,3 +42,47 @@ class PermissionDenied(CustomHttpException): status_code: int = status.HTTP_403_FORBIDDEN, ): super().__init__(code=code, detail=detail, status_code=status_code) + + +class ValidationError(CustomHttpException): + def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): + super().__init__(code=code, detail=detail, status_code=status_code) + + +def flatten_errors(field_name, errors): + ret = [] + if isinstance(errors, dict): + for error_key in errors: + error = errors[error_key] + ret.extend(flatten_errors("{}.{}".format(field_name, error_key), error)) + else: + for error in errors: + if error.messages: + message = error.messages[0] + else: + message = str(error) + ret.append( + { + "field": field_name, + "code": error.code, + "detail": message, + } + ) + return ret + + +def transform_validation_error(prefix, err): + if hasattr(err, "error_dict"): + errors = flatten_errors(prefix, err.error_dict) + elif not hasattr(err, "message"): + errors = flatten_errors(prefix, err.error_list) + else: + raise EtebaseValidationError(err.code, err.message) + raise ValidationError(code="field_errors", detail="Field validations failed.") + raise serializers.ValidationError( + { + "code": "field_errors", + "detail": "Field validations failed.", + "errors": errors, + } + )