database: write more of the notification system
This commit is contained in:
parent
90fe137de8
commit
8be18a0a01
@ -49,7 +49,7 @@ type Datastore interface {
|
||||
|
||||
// Notifications
|
||||
GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) // Does not fill old/new Vulnerabilities.
|
||||
GetNotification(name string, limit, page int) (VulnerabilityNotification, error)
|
||||
GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error)
|
||||
SetNotificationNotified(name string) error
|
||||
DeleteNotification(name string) error
|
||||
|
||||
|
@ -22,7 +22,7 @@ import (
|
||||
|
||||
// ID is only meant to be used by database implementations and should never be used for anything else.
|
||||
type Model struct {
|
||||
ID int `json:"-"`
|
||||
ID int
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
@ -74,6 +74,8 @@ type Vulnerability struct {
|
||||
}
|
||||
|
||||
type VulnerabilityNotification struct {
|
||||
Model
|
||||
|
||||
Name string
|
||||
|
||||
Created time.Time
|
||||
@ -83,3 +85,12 @@ type VulnerabilityNotification struct {
|
||||
OldVulnerability *Vulnerability
|
||||
NewVulnerability Vulnerability
|
||||
}
|
||||
|
||||
type VulnerabilityNotificationPageNumber struct {
|
||||
// -1 means that we reached the end already.
|
||||
OldVulnerability int
|
||||
NewVulnerability int
|
||||
}
|
||||
|
||||
var VulnerabilityNotificationFirstPage = VulnerabilityNotificationPageNumber{0, 0}
|
||||
var NoVulnerabilityNotificationPage = VulnerabilityNotificationPageNumber{-1, -1}
|
||||
|
@ -36,7 +36,7 @@ const (
|
||||
)
|
||||
|
||||
func TestRaceAffects(t *testing.T) {
|
||||
datastore, err := OpenForTest("TestRaceAffects", false)
|
||||
datastore, err := OpenForTest("RaceAffects", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -151,7 +151,7 @@ CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
notified_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
old_vulnerability TEXT,
|
||||
old_vulnerability TEXT NULL,
|
||||
new_vulnerability TEXT);
|
||||
|
||||
CREATE INDEX ON Vulnerability_Notification (notified_at);
|
||||
|
@ -7,20 +7,27 @@ import (
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/guregu/null/zero"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
// do it in tx so we won't insert/update a vuln without notification and vice-versa.
|
||||
// name and created doesn't matter.
|
||||
// Vuln ID must be filled in.
|
||||
func (pgSQL *pgSQL) insertNotification(tx *sql.Tx, notification database.VulnerabilityNotification) error {
|
||||
defer observeQueryTime("insertNotification", "all", time.Now())
|
||||
|
||||
// Marshal old and new Vulnerabilities.
|
||||
oldVulnerability, err := json.Marshal(notification.OldVulnerability)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return cerrors.NewBadRequestError("could not marshal old Vulnerability in insertNotification")
|
||||
var oldVulnerability sql.NullString
|
||||
if notification.OldVulnerability != nil {
|
||||
oldVulnerabilityJSON, err := json.Marshal(notification.OldVulnerability)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return cerrors.NewBadRequestError("could not marshal old Vulnerability in insertNotification")
|
||||
}
|
||||
oldVulnerability = sql.NullString{String: string(oldVulnerabilityJSON), Valid: true}
|
||||
}
|
||||
|
||||
newVulnerability, err := json.Marshal(notification.NewVulnerability)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
@ -43,46 +50,127 @@ func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (da
|
||||
defer observeQueryTime("GetAvailableNotification", "all", time.Now())
|
||||
|
||||
before := time.Now().Add(-renotifyInterval)
|
||||
row := pgSQL.QueryRow(getQuery("s_notification_available"), before)
|
||||
notification, err := scanNotification(row, false)
|
||||
|
||||
var notification database.VulnerabilityNotification
|
||||
err := pgSQL.QueryRow(getQuery("s_notification_available"), before).Scan(¬ification.Name,
|
||||
¬ification.Created, ¬ification.Notified, ¬ification.Deleted)
|
||||
if err != nil {
|
||||
return notification, handleError("s_notification_available", err)
|
||||
}
|
||||
|
||||
return notification, nil
|
||||
return notification, handleError("s_notification_available", err)
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) GetNotification(name string, limit, page int) (database.VulnerabilityNotification, error) {
|
||||
func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
|
||||
defer observeQueryTime("GetNotification", "all", time.Now())
|
||||
|
||||
// Get Notification.
|
||||
var notification database.VulnerabilityNotification
|
||||
notification, err := scanNotification(pgSQL.QueryRow(getQuery("s_notification"), name), true)
|
||||
if err != nil {
|
||||
return notification, page, handleError("s_notification", err)
|
||||
}
|
||||
|
||||
// Load vulnerabilities' LayersIntroducingVulnerability.
|
||||
page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
|
||||
notification.OldVulnerability, limit, page.OldVulnerability)
|
||||
if err != nil {
|
||||
return notification, page, err
|
||||
}
|
||||
|
||||
page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
|
||||
¬ification.NewVulnerability, limit, page.NewVulnerability)
|
||||
if err != nil {
|
||||
return notification, page, err
|
||||
}
|
||||
|
||||
return notification, page, nil
|
||||
}
|
||||
|
||||
func scanNotification(row *sql.Row, hasVulns bool) (notification database.VulnerabilityNotification, err error) {
|
||||
var created zero.Time
|
||||
var notified zero.Time
|
||||
var deleted zero.Time
|
||||
var oldVulnerability []byte
|
||||
var newVulnerability []byte
|
||||
|
||||
err := pgSQL.QueryRow(getQuery("s_notification"), name).Scan(¬ification.Name,
|
||||
¬ification.Created, ¬ification.Notified, ¬ification.Deleted, &newVulnerability,
|
||||
&oldVulnerability)
|
||||
// Query notification.
|
||||
if hasVulns {
|
||||
err = row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted,
|
||||
&oldVulnerability, &newVulnerability)
|
||||
} else {
|
||||
err = row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted)
|
||||
}
|
||||
if err != nil {
|
||||
return notification, handleError("s_notification", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal old and new Vulnerabilities.
|
||||
err = json.Unmarshal(oldVulnerability, notification.OldVulnerability)
|
||||
if err != nil {
|
||||
return notification, cerrors.NewBadRequestError("could not unmarshal old Vulnerability in GetNotification")
|
||||
}
|
||||
err = json.Unmarshal(newVulnerability, ¬ification.NewVulnerability)
|
||||
if err != nil {
|
||||
return notification, cerrors.NewBadRequestError("could not unmarshal new Vulnerability in GetNotification")
|
||||
notification.Created = created.Time
|
||||
notification.Notified = notified.Time
|
||||
notification.Deleted = deleted.Time
|
||||
|
||||
if hasVulns {
|
||||
// Unmarshal old and new Vulnerabilities.
|
||||
err = json.Unmarshal(oldVulnerability, notification.OldVulnerability)
|
||||
if err != nil {
|
||||
err = cerrors.NewBadRequestError("could not unmarshal old Vulnerability in GetNotification")
|
||||
}
|
||||
|
||||
err = json.Unmarshal(newVulnerability, ¬ification.NewVulnerability)
|
||||
if err != nil {
|
||||
err = cerrors.NewBadRequestError("could not unmarshal new Vulnerability in GetNotification")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(Quentin-M): Fill LayersIntroducingVulnerability.
|
||||
// And time it.
|
||||
return
|
||||
}
|
||||
|
||||
return notification, nil
|
||||
// Fills Vulnerability.LayersIntroducingVulnerability.
|
||||
// limit -1: won't do anything
|
||||
// limit 0: will just get the startID of the second page
|
||||
func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) {
|
||||
tf := time.Now()
|
||||
|
||||
if vulnerability == nil {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// A startID equals to -1 means that we reached the end already.
|
||||
if startID == -1 || limit == -1 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe invalid calls.
|
||||
defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf)
|
||||
|
||||
// Query with limit + 1, the last item will be used to know the next starting ID.
|
||||
rows, err := pgSQL.Query(getQuery("s_notification_layer_introducing_vulnerability"),
|
||||
vulnerability.ID, startID, limit+1)
|
||||
if err != nil {
|
||||
return 0, handleError("s_vulnerability_fixedin_feature", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var layers []database.Layer
|
||||
for rows.Next() {
|
||||
var layer database.Layer
|
||||
|
||||
if err := rows.Scan(&layer.ID, &layer.Name); err != nil {
|
||||
return -1, handleError("s_notification_layer_introducing_vulnerability.Scan()", err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return -1, handleError("s_notification_layer_introducing_vulnerability.Rows()", err)
|
||||
}
|
||||
|
||||
size := limit
|
||||
if len(layers) < limit {
|
||||
size = len(layers)
|
||||
}
|
||||
vulnerability.LayersIntroducingVulnerability = layers[:size]
|
||||
|
||||
nextID := -1
|
||||
if len(layers) > limit {
|
||||
nextID = layers[limit].ID
|
||||
}
|
||||
|
||||
return nextID, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) SetNotificationNotified(name string) error {
|
||||
|
123
database/pgsql/notification_test.go
Normal file
123
database/pgsql/notification_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNotification(t *testing.T) {
|
||||
datastore, err := OpenForTest("Notification", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Try to get a notification when there is none.
|
||||
_, err = datastore.GetAvailableNotification(time.Second)
|
||||
assert.Equal(t, cerrors.ErrNotFound, err)
|
||||
|
||||
// Create some data.
|
||||
f1 := database.Feature{
|
||||
Name: "TestNotificationFeature1",
|
||||
Namespace: database.Namespace{Name: "TestNotificationNamespace1"},
|
||||
}
|
||||
|
||||
l1 := database.Layer{
|
||||
Name: "TestNotificationLayer1",
|
||||
Features: []database.FeatureVersion{
|
||||
database.FeatureVersion{
|
||||
Feature: f1,
|
||||
Version: types.NewVersionUnsafe("0.1"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
l2 := database.Layer{
|
||||
Name: "TestNotificationLayer2",
|
||||
Features: []database.FeatureVersion{
|
||||
database.FeatureVersion{
|
||||
Feature: f1,
|
||||
Version: types.NewVersionUnsafe("0.2"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
l3 := database.Layer{
|
||||
Name: "TestNotificationLayer3",
|
||||
Features: []database.FeatureVersion{
|
||||
database.FeatureVersion{
|
||||
Feature: f1,
|
||||
Version: types.NewVersionUnsafe("0.3"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if assert.Nil(t, datastore.InsertLayer(l1)) && assert.Nil(t, datastore.InsertLayer(l2)) &&
|
||||
assert.Nil(t, datastore.InsertLayer(l3)) {
|
||||
|
||||
// Insert a new vulnerability that is introduced by three layers.
|
||||
v1 := database.Vulnerability{
|
||||
Name: "TestNotificationVulnerability1",
|
||||
Namespace: f1.Namespace,
|
||||
Description: "TestNotificationDescription1",
|
||||
Link: "TestNotificationLink1",
|
||||
Severity: "Unknown",
|
||||
FixedIn: []database.FeatureVersion{
|
||||
database.FeatureVersion{
|
||||
Feature: f1,
|
||||
Version: types.NewVersionUnsafe("1.0"),
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Nil(t, datastore.insertVulnerability(v1))
|
||||
|
||||
// Get the notification associated to the previously inserted vulnerability.
|
||||
notification, err := datastore.GetAvailableNotification(time.Second)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, notification.Name)
|
||||
|
||||
// Verify the renotify behaviour.
|
||||
if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) {
|
||||
_, err := datastore.GetAvailableNotification(time.Second)
|
||||
assert.Equal(t, cerrors.ErrNotFound, err)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, notification.Name, notificationB.Name)
|
||||
|
||||
datastore.SetNotificationNotified(notification.Name)
|
||||
}
|
||||
|
||||
// Get notification.
|
||||
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage)
|
||||
assert.Nil(t, filledNotification.OldVulnerability)
|
||||
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
|
||||
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2)
|
||||
|
||||
// Get second page.
|
||||
filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage)
|
||||
assert.Nil(t, filledNotification.OldVulnerability)
|
||||
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
|
||||
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
|
||||
|
||||
// Delete notification.
|
||||
assert.Nil(t, datastore.DeleteNotification(notification.Name))
|
||||
|
||||
n, err := datastore.GetAvailableNotification(time.Millisecond)
|
||||
assert.Equal(t, cerrors.ErrNotFound, err)
|
||||
fmt.Println(n)
|
||||
}
|
||||
}
|
@ -236,6 +236,10 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
|
||||
// handleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||
// This ensures we never return plain SQL errors and leak anything.
|
||||
func handleError(desc string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return cerrors.ErrNotFound
|
||||
}
|
||||
|
@ -190,21 +190,43 @@ func init() {
|
||||
INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability, new_vulnerability)
|
||||
VALUES($1, CURRENT_TIMESTAMP, $2, $3)`
|
||||
|
||||
queries["r_notification"] = `UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP`
|
||||
queries["u_notification_notified"] = `
|
||||
UPDATE Vulnerability_Notification
|
||||
SET notified_at = CURRENT_TIMESTAMP
|
||||
WHERE name = $1`
|
||||
|
||||
queries["r_notification"] = `
|
||||
UPDATE Vulnerability_Notification
|
||||
SET deleted_at = CURRENT_TIMESTAMP
|
||||
WHERE name = $1`
|
||||
|
||||
queries["s_notification_available"] = `
|
||||
SELECT name, created_at, notified_at, deleted_at
|
||||
SELECT id, name, created_at, notified_at, deleted_at
|
||||
FROM Vulnerability_Notification
|
||||
WHERE notified_at = NULL OR notified_at < $1
|
||||
WHERE (notified_at IS NULL OR notified_at < $1)
|
||||
AND deleted_at IS NULL
|
||||
AND name NOT IN (SELECT name FROM Lock)
|
||||
ORDER BY Random()
|
||||
LIMIT 1`
|
||||
|
||||
queries["s_notification"] = `
|
||||
SELECT name, created_at, notified_at, deleted_at, old_vulnerability, new_vulnerability
|
||||
SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability, new_vulnerability
|
||||
FROM Vulnerability_Notification
|
||||
WHERE name = $1`
|
||||
|
||||
queries["s_notification_layer_introducing_vulnerability"] = `
|
||||
SELECT l.ID, l.name
|
||||
FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l
|
||||
WHERE v.id = $1
|
||||
AND v.id = vafv.vulnerability_id
|
||||
AND vafv.featureversion_id = fv.id
|
||||
AND fv.id = ldfv.featureversion_id
|
||||
AND ldfv.modification = 'add'
|
||||
AND ldfv.layer_id = l.id
|
||||
AND l.id >= $2
|
||||
ORDER BY l.ID
|
||||
LIMIT $3`
|
||||
|
||||
// complex_test.go
|
||||
queries["s_complextest_featureversion_affects"] = `
|
||||
SELECT v.name
|
||||
|
@ -223,17 +223,13 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er
|
||||
}
|
||||
|
||||
// Create notification.
|
||||
var notification database.VulnerabilityNotification
|
||||
if existingVulnerability.ID == 0 {
|
||||
notification = database.VulnerabilityNotification{
|
||||
NewVulnerability: vulnerability,
|
||||
}
|
||||
} else {
|
||||
notification = database.VulnerabilityNotification{
|
||||
OldVulnerability: &existingVulnerability,
|
||||
NewVulnerability: vulnerability,
|
||||
}
|
||||
notification := database.VulnerabilityNotification{
|
||||
NewVulnerability: vulnerability,
|
||||
}
|
||||
if existingVulnerability.ID != 0 {
|
||||
notification.OldVulnerability = &existingVulnerability
|
||||
}
|
||||
|
||||
if err := pgSQL.insertNotification(tx, notification); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ func (v Version) MarshalJSON() ([]byte, error) {
|
||||
func (v *Version) UnmarshalJSON(b []byte) (err error) {
|
||||
var str string
|
||||
json.Unmarshal(b, &str)
|
||||
vp, err := NewVersion(str)
|
||||
vp := NewVersionUnsafe(str)
|
||||
*v = vp
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user