// Copyright 2017 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package notification import ( "database/sql" "errors" "time" "github.com/guregu/null/zero" "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/database/pgsql/vulnerability" "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/pagination" ) const ( insertNotification = ` INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) VALUES ($1, $2, $3, $4)` updatedNotificationAsRead = ` UPDATE Vulnerability_Notification SET notified_at = CURRENT_TIMESTAMP WHERE name = $1` removeNotification = ` UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP WHERE name = $1 AND deleted_at IS NULL` searchNotificationAvailable = ` SELECT name, created_at, notified_at, deleted_at FROM Vulnerability_Notification WHERE (notified_at IS NULL OR notified_at < $1) AND deleted_at IS NULL AND name NOT IN (SELECT name FROM Lock) ORDER BY Random() LIMIT 1` searchNotification = ` SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` ) func queryInsertNotifications(count int) string { return util.QueryInsert(count, "vulnerability_notification", "name", "created_at", "old_vulnerability_id", "new_vulnerability_id", ) } var ( errNotificationNotFound = errors.New("requested notification is not found") errVulnerabilityNotFound = errors.New("vulnerability is not in database") ) func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error { if len(notifications) == 0 { return nil } var ( newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) ) invalidCreationTime := time.Time{} for _, noti := range notifications { if noti.Name == "" { return commonerr.NewBadRequestError("notification should not have empty name") } if noti.Created == invalidCreationTime { return commonerr.NewBadRequestError("notification should not have empty created time") } if noti.New != nil { key := database.VulnerabilityID{ Name: noti.New.Name, Namespace: noti.New.Namespace.Name, } newVulnIDMap[key] = sql.NullInt64{} } if noti.Old != nil { key := database.VulnerabilityID{ Name: noti.Old.Name, Namespace: noti.Old.Namespace.Name, } oldVulnIDMap[key] = sql.NullInt64{} } } var ( newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap)) oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap)) ) for vulnID := range newVulnIDMap { newVulnIDs = append(newVulnIDs, vulnID) } for vulnID := range oldVulnIDMap { oldVulnIDs = append(oldVulnIDs, vulnID) } ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) } newVulnIDMap[newVulnIDs[i]] = id } ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) } oldVulnIDMap[oldVulnIDs[i]] = id } var ( newVulnID sql.NullInt64 oldVulnID sql.NullInt64 ) keys := make([]interface{}, len(notifications)*4) for i, noti := range notifications { if noti.New != nil { newVulnID = newVulnIDMap[database.VulnerabilityID{ Name: noti.New.Name, Namespace: noti.New.Namespace.Name, }] } if noti.Old != nil { oldVulnID = oldVulnIDMap[database.VulnerabilityID{ Name: noti.Old.Name, Namespace: noti.Old.Namespace.Name, }] } keys[4*i] = noti.Name keys[4*i+1] = noti.Created keys[4*i+2] = oldVulnID keys[4*i+3] = newVulnID } // NOTE(Sida): The data is not sorted before inserting into database under // the fact that there's only one updater running at a time. If there are // multiple updaters, deadlock may happen. _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) if err != nil { return util.HandleError("queryInsertNotifications", err) } return nil } func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) { var ( notification database.NotificationHook created zero.Time notified zero.Time deleted zero.Time ) err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(¬ification.Name, &created, ¬ified, &deleted) if err != nil { if err == sql.ErrNoRows { return notification, false, nil } return notification, false, util.HandleError("searchNotificationAvailable", err) } notification.Created = created.Time notification.Notified = notified.Time notification.Deleted = deleted.Time return notification, true, nil } func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable oldVulnID sql.NullInt64 newVulnID sql.NullInt64 created zero.Time notified zero.Time deleted zero.Time ) if name == "" { return noti, false, commonerr.NewBadRequestError("Empty notification name is not allowed") } noti.Name = name err := tx.QueryRow(searchNotification, name).Scan(&created, ¬ified, &deleted, &oldVulnID, &newVulnID) if err != nil { if err == sql.ErrNoRows { return noti, false, nil } return noti, false, util.HandleError("searchNotification", err) } if created.Valid { noti.Created = created.Time } if notified.Valid { noti.Notified = notified.Time } if deleted.Valid { noti.Deleted = deleted.Time } if oldVulnID.Valid { page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key) if err != nil { return noti, false, err } noti.Old = &page } if newVulnID.Valid { page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key) if err != nil { return noti, false, err } noti.New = &page } return noti, true, nil } func MarkNotificationAsRead(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } r, err := tx.Exec(updatedNotificationAsRead, name) if err != nil { return util.HandleError("updatedNotificationAsRead", err) } affected, err := r.RowsAffected() if err != nil { return util.HandleError("updatedNotificationAsRead", err) } if affected <= 0 { return util.HandleError("updatedNotificationAsRead", errNotificationNotFound) } return nil } func DeleteNotification(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } result, err := tx.Exec(removeNotification, name) if err != nil { return util.HandleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { return util.HandleError("removeNotification", err) } if affected <= 0 { return util.HandleError("removeNotification", commonerr.ErrNotFound) } return nil }