// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package notification

import (
	"database/sql"
	"errors"
	"time"

	"github.com/guregu/null/zero"

	"github.com/coreos/clair/database"
	"github.com/coreos/clair/database/pgsql/util"
	"github.com/coreos/clair/database/pgsql/vulnerability"
	"github.com/coreos/clair/pkg/commonerr"
	"github.com/coreos/clair/pkg/pagination"
)

// SQL statements used by the functions in this file. Arguments are bound to
// the $n placeholders.
const (
	// insertNotification inserts a single notification row.
	// NOTE(review): not referenced in the visible part of this file; bulk
	// inserts go through queryInsertNotifications instead.
	insertNotification = `
		INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id)
		VALUES ($1, $2, $3, $4)`

	// updatedNotificationAsRead stamps notified_at with the current time for
	// the notification named $1. It does not filter on deleted_at, so a
	// soft-deleted notification can still be marked as read.
	updatedNotificationAsRead = `
		UPDATE Vulnerability_Notification
		SET notified_at = CURRENT_TIMESTAMP
		WHERE name = $1`

	// removeNotification soft-deletes the notification named $1 by setting
	// deleted_at; rows that are already deleted are left untouched.
	removeNotification = `
		UPDATE Vulnerability_Notification
	  SET deleted_at = CURRENT_TIMESTAMP
	  WHERE name = $1 AND deleted_at IS NULL`

	// searchNotificationAvailable picks one random, non-deleted notification
	// that was never notified or was last notified before $1, skipping names
	// present in the Lock table (presumably notifications currently held by
	// another worker — confirm).
	searchNotificationAvailable = `
		SELECT name, created_at, notified_at, deleted_at
		FROM Vulnerability_Notification
		WHERE (notified_at IS NULL OR notified_at < $1)
					AND deleted_at IS NULL
					AND name NOT IN (SELECT name FROM Lock)
		ORDER BY Random()
		LIMIT 1`

	// searchNotification fetches the timestamps and the old/new vulnerability
	// IDs of the notification named $1. Deleted notifications are not excluded.
	searchNotification = `
		SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
		FROM Vulnerability_Notification
		WHERE name = $1`
)

// queryInsertNotifications builds, via util.QueryInsert, an INSERT statement
// for count rows of vulnerability_notification. The callers must supply four
// arguments per row, in the column order name, created_at,
// old_vulnerability_id, new_vulnerability_id.
func queryInsertNotifications(count int) string {
	const table = "vulnerability_notification"
	return util.QueryInsert(count, table, "name", "created_at", "old_vulnerability_id", "new_vulnerability_id")
}

// errNotificationNotFound is reported by MarkNotificationAsRead when its update
// matches no notification row.
var errNotificationNotFound = errors.New("requested notification is not found")

// errVulnerabilityNotFound is reported by InsertVulnerabilityNotifications when
// a vulnerability referenced by a notification cannot be resolved to an ID.
var errVulnerabilityNotFound = errors.New("vulnerability is not in database")

// InsertVulnerabilityNotifications inserts the given notifications into the
// Vulnerability_Notification table within tx, using a single multi-row INSERT.
//
// Every notification must have a non-empty Name and a non-zero Created time;
// otherwise a bad-request error is returned and nothing is inserted.
//
// The New vulnerability of a notification, when set, is resolved with
// vulnerability.FindNotDeletedVulnerabilityIDs. The Old vulnerability, when
// set, is resolved with vulnerability.FindLatestDeletedVulnerabilityIDs. If
// any referenced vulnerability cannot be resolved, errVulnerabilityNotFound is
// returned through util.HandleError. A notification without a New (or Old)
// vulnerability stores NULL in the corresponding column.
func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error {
	if len(notifications) == 0 {
		return nil
	}

	var (
		newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
		oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
	)

	// Validate the input and collect the distinct vulnerabilities to resolve.
	for _, noti := range notifications {
		if noti.Name == "" {
			return commonerr.NewBadRequestError("notification should not have empty name")
		}
		// IsZero rather than == time.Time{}: == also compares the Location,
		// so a zero instant carrying a location would slip past the check.
		if noti.Created.IsZero() {
			return commonerr.NewBadRequestError("notification should not have empty created time")
		}

		if noti.New != nil {
			key := database.VulnerabilityID{
				Name:      noti.New.Name,
				Namespace: noti.New.Namespace.Name,
			}
			newVulnIDMap[key] = sql.NullInt64{}
		}

		if noti.Old != nil {
			key := database.VulnerabilityID{
				Name:      noti.Old.Name,
				Namespace: noti.Old.Namespace.Name,
			}
			oldVulnIDMap[key] = sql.NullInt64{}
		}
	}

	var (
		newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap))
		oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap))
	)

	for vulnID := range newVulnIDMap {
		newVulnIDs = append(newVulnIDs, vulnID)
	}

	for vulnID := range oldVulnIDMap {
		oldVulnIDs = append(oldVulnIDs, vulnID)
	}

	// The returned IDs are matched to the queried keys by position.
	ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs)
	if err != nil {
		return err
	}

	for i, id := range ids {
		if !id.Valid {
			return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
		}
		newVulnIDMap[newVulnIDs[i]] = id
	}

	ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs)
	if err != nil {
		return err
	}

	for i, id := range ids {
		if !id.Valid {
			return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
		}
		oldVulnIDMap[oldVulnIDs[i]] = id
	}

	// Four arguments per row, in the column order of queryInsertNotifications.
	keys := make([]interface{}, 0, len(notifications)*4)
	for _, noti := range notifications {
		// Declared per iteration so that a notification without New/Old
		// stores NULL instead of inheriting the previous row's ID.
		var newVulnID, oldVulnID sql.NullInt64

		if noti.New != nil {
			newVulnID = newVulnIDMap[database.VulnerabilityID{
				Name:      noti.New.Name,
				Namespace: noti.New.Namespace.Name,
			}]
		}

		if noti.Old != nil {
			oldVulnID = oldVulnIDMap[database.VulnerabilityID{
				Name:      noti.Old.Name,
				Namespace: noti.Old.Namespace.Name,
			}]
		}

		keys = append(keys, noti.Name, noti.Created, oldVulnID, newVulnID)
	}

	// NOTE(Sida): The data is not sorted before inserting into database under
	// the fact that there's only one updater running at a time. If there are
	// multiple updaters, deadlock may happen.
	_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
	if err != nil {
		return util.HandleError("queryInsertNotifications", err)
	}

	return nil
}

// FindNewNotification picks, at random, one notification that is not deleted,
// was never notified or was last notified before notifiedBefore, and whose
// name is not present in the Lock table.
//
// The bool result is false with a nil error when no such notification exists.
// Database errors are returned through util.HandleError.
func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) {
	var hook database.NotificationHook
	var createdAt, notifiedAt, deletedAt zero.Time

	row := tx.QueryRow(searchNotificationAvailable, notifiedBefore)
	switch err := row.Scan(&hook.Name, &createdAt, &notifiedAt, &deletedAt); {
	case err == sql.ErrNoRows:
		return hook, false, nil
	case err != nil:
		return hook, false, util.HandleError("searchNotificationAvailable", err)
	}

	// NULL columns leave zero.Time's Time at the zero value.
	hook.Created, hook.Notified, hook.Deleted = createdAt.Time, notifiedAt.Time, deletedAt.Time
	return hook, true, nil
}

// FindVulnerabilityNotification loads the notification called name. For each
// of its old and new vulnerabilities, it also loads one page of vulnerable
// ancestries.
//
// limit and key are forwarded to vulnerability.FindPagedVulnerableAncestries.
// oldPageToken selects the page for the old vulnerability and newPageToken the
// page for the new one. A side whose vulnerability ID is NULL is left nil.
//
// An empty name yields a bad-request error. The bool result is false with a
// nil error when no notification has that name.
func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) (
	database.VulnerabilityNotificationWithVulnerable, bool, error) {
	var found database.VulnerabilityNotificationWithVulnerable
	if name == "" {
		return found, false, commonerr.NewBadRequestError("Empty notification name is not allowed")
	}
	found.Name = name

	var createdAt, notifiedAt, deletedAt zero.Time
	var oldID, newID sql.NullInt64

	row := tx.QueryRow(searchNotification, name)
	switch err := row.Scan(&createdAt, &notifiedAt, &deletedAt, &oldID, &newID); {
	case err == sql.ErrNoRows:
		return found, false, nil
	case err != nil:
		return found, false, util.HandleError("searchNotification", err)
	}

	// Only copy timestamps that were present in the row.
	if createdAt.Valid {
		found.Created = createdAt.Time
	}
	if notifiedAt.Valid {
		found.Notified = notifiedAt.Time
	}
	if deletedAt.Valid {
		found.Deleted = deletedAt.Time
	}

	if oldID.Valid {
		oldPage, err := vulnerability.FindPagedVulnerableAncestries(tx, oldID.Int64, limit, oldPageToken, key)
		if err != nil {
			return found, false, err
		}
		found.Old = &oldPage
	}

	if newID.Valid {
		newPage, err := vulnerability.FindPagedVulnerableAncestries(tx, newID.Int64, limit, newPageToken, key)
		if err != nil {
			return found, false, err
		}
		found.New = &newPage
	}

	return found, true, nil
}

// MarkNotificationAsRead sets notified_at of the notification called name to
// the current timestamp.
//
// An empty name yields a bad-request error. If the update matches no row,
// errNotificationNotFound is returned through util.HandleError.
func MarkNotificationAsRead(tx *sql.Tx, name string) error {
	if name == "" {
		return commonerr.NewBadRequestError("Empty notification name is not allowed")
	}

	const op = "updatedNotificationAsRead"

	res, execErr := tx.Exec(updatedNotificationAsRead, name)
	if execErr != nil {
		return util.HandleError(op, execErr)
	}

	n, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return util.HandleError(op, countErr)
	case n <= 0:
		return util.HandleError(op, errNotificationNotFound)
	}
	return nil
}

// DeleteNotification soft-deletes the notification called name by setting its
// deleted_at timestamp. Only rows not already deleted are affected.
//
// An empty name yields a bad-request error. If no such row exists,
// commonerr.ErrNotFound is returned through util.HandleError.
func DeleteNotification(tx *sql.Tx, name string) error {
	if name == "" {
		return commonerr.NewBadRequestError("Empty notification name is not allowed")
	}

	outcome, err := tx.Exec(removeNotification, name)
	if err != nil {
		return util.HandleError("removeNotification", err)
	}

	rows, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return util.HandleError("removeNotification", err)
	case rows <= 0:
		return util.HandleError("removeNotification", commonerr.ErrNotFound)
	default:
		return nil
	}
}