database: write more of the notification system

This commit is contained in:
Quentin Machu 2016-01-26 17:57:32 -05:00 committed by Jimmy Zelinskie
parent 90fe137de8
commit 8be18a0a01
10 changed files with 292 additions and 48 deletions

View File

@ -49,7 +49,7 @@ type Datastore interface {
// Notifications // Notifications
GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) // Does not fill old/new Vulnerabilities. GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) // Does not fill old/new Vulnerabilities.
GetNotification(name string, limit, page int) (VulnerabilityNotification, error) GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error)
SetNotificationNotified(name string) error SetNotificationNotified(name string) error
DeleteNotification(name string) error DeleteNotification(name string) error

View File

@ -22,7 +22,7 @@ import (
// ID is only meant to be used by database implementations and should never be used for anything else. // ID is only meant to be used by database implementations and should never be used for anything else.
type Model struct { type Model struct {
ID int `json:"-"` ID int
} }
type Layer struct { type Layer struct {
@ -74,6 +74,8 @@ type Vulnerability struct {
} }
type VulnerabilityNotification struct { type VulnerabilityNotification struct {
Model
Name string Name string
Created time.Time Created time.Time
@ -83,3 +85,12 @@ type VulnerabilityNotification struct {
OldVulnerability *Vulnerability OldVulnerability *Vulnerability
NewVulnerability Vulnerability NewVulnerability Vulnerability
} }
type VulnerabilityNotificationPageNumber struct {
// -1 means that we reached the end already.
OldVulnerability int
NewVulnerability int
}
var VulnerabilityNotificationFirstPage = VulnerabilityNotificationPageNumber{0, 0}
var NoVulnerabilityNotificationPage = VulnerabilityNotificationPageNumber{-1, -1}

View File

@ -36,7 +36,7 @@ const (
) )
func TestRaceAffects(t *testing.T) { func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("TestRaceAffects", false) datastore, err := OpenForTest("RaceAffects", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -151,7 +151,7 @@ CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
created_at TIMESTAMP WITH TIME ZONE, created_at TIMESTAMP WITH TIME ZONE,
notified_at TIMESTAMP WITH TIME ZONE NULL, notified_at TIMESTAMP WITH TIME ZONE NULL,
deleted_at TIMESTAMP WITH TIME ZONE NULL, deleted_at TIMESTAMP WITH TIME ZONE NULL,
old_vulnerability TEXT, old_vulnerability TEXT NULL,
new_vulnerability TEXT); new_vulnerability TEXT);
CREATE INDEX ON Vulnerability_Notification (notified_at); CREATE INDEX ON Vulnerability_Notification (notified_at);

View File

@ -7,20 +7,27 @@ import (
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/guregu/null/zero"
"github.com/pborman/uuid" "github.com/pborman/uuid"
) )
// do it in tx so we won't insert/update a vuln without notification and vice-versa. // do it in tx so we won't insert/update a vuln without notification and vice-versa.
// name and created doesn't matter. // name and created doesn't matter.
// Vuln ID must be filled in.
func (pgSQL *pgSQL) insertNotification(tx *sql.Tx, notification database.VulnerabilityNotification) error { func (pgSQL *pgSQL) insertNotification(tx *sql.Tx, notification database.VulnerabilityNotification) error {
defer observeQueryTime("insertNotification", "all", time.Now()) defer observeQueryTime("insertNotification", "all", time.Now())
// Marshal old and new Vulnerabilities. // Marshal old and new Vulnerabilities.
oldVulnerability, err := json.Marshal(notification.OldVulnerability) var oldVulnerability sql.NullString
if notification.OldVulnerability != nil {
oldVulnerabilityJSON, err := json.Marshal(notification.OldVulnerability)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return cerrors.NewBadRequestError("could not marshal old Vulnerability in insertNotification") return cerrors.NewBadRequestError("could not marshal old Vulnerability in insertNotification")
} }
oldVulnerability = sql.NullString{String: string(oldVulnerabilityJSON), Valid: true}
}
newVulnerability, err := json.Marshal(notification.NewVulnerability) newVulnerability, err := json.Marshal(notification.NewVulnerability)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -43,46 +50,127 @@ func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (da
defer observeQueryTime("GetAvailableNotification", "all", time.Now()) defer observeQueryTime("GetAvailableNotification", "all", time.Now())
before := time.Now().Add(-renotifyInterval) before := time.Now().Add(-renotifyInterval)
row := pgSQL.QueryRow(getQuery("s_notification_available"), before)
notification, err := scanNotification(row, false)
var notification database.VulnerabilityNotification
err := pgSQL.QueryRow(getQuery("s_notification_available"), before).Scan(&notification.Name,
&notification.Created, &notification.Notified, &notification.Deleted)
if err != nil {
return notification, handleError("s_notification_available", err) return notification, handleError("s_notification_available", err)
} }
return notification, nil func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
}
func (pgSQL *pgSQL) GetNotification(name string, limit, page int) (database.VulnerabilityNotification, error) {
defer observeQueryTime("GetNotification", "all", time.Now()) defer observeQueryTime("GetNotification", "all", time.Now())
// Get Notification. // Get Notification.
var notification database.VulnerabilityNotification notification, err := scanNotification(pgSQL.QueryRow(getQuery("s_notification"), name), true)
if err != nil {
return notification, page, handleError("s_notification", err)
}
// Load vulnerabilities' LayersIntroducingVulnerability.
page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
notification.OldVulnerability, limit, page.OldVulnerability)
if err != nil {
return notification, page, err
}
page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
&notification.NewVulnerability, limit, page.NewVulnerability)
if err != nil {
return notification, page, err
}
return notification, page, nil
}
func scanNotification(row *sql.Row, hasVulns bool) (notification database.VulnerabilityNotification, err error) {
var created zero.Time
var notified zero.Time
var deleted zero.Time
var oldVulnerability []byte var oldVulnerability []byte
var newVulnerability []byte var newVulnerability []byte
err := pgSQL.QueryRow(getQuery("s_notification"), name).Scan(&notification.Name, // Query notification.
&notification.Created, &notification.Notified, &notification.Deleted, &newVulnerability, if hasVulns {
&oldVulnerability) err = row.Scan(&notification.ID, &notification.Name, &created, &notified, &deleted,
&oldVulnerability, &newVulnerability)
} else {
err = row.Scan(&notification.ID, &notification.Name, &created, &notified, &deleted)
}
if err != nil { if err != nil {
return notification, handleError("s_notification", err) return
} }
notification.Created = created.Time
notification.Notified = notified.Time
notification.Deleted = deleted.Time
if hasVulns {
// Unmarshal old and new Vulnerabilities. // Unmarshal old and new Vulnerabilities.
err = json.Unmarshal(oldVulnerability, notification.OldVulnerability) err = json.Unmarshal(oldVulnerability, notification.OldVulnerability)
if err != nil { if err != nil {
return notification, cerrors.NewBadRequestError("could not unmarshal old Vulnerability in GetNotification") err = cerrors.NewBadRequestError("could not unmarshal old Vulnerability in GetNotification")
} }
err = json.Unmarshal(newVulnerability, &notification.NewVulnerability) err = json.Unmarshal(newVulnerability, &notification.NewVulnerability)
if err != nil { if err != nil {
return notification, cerrors.NewBadRequestError("could not unmarshal new Vulnerability in GetNotification") err = cerrors.NewBadRequestError("could not unmarshal new Vulnerability in GetNotification")
}
} }
// TODO(Quentin-M): Fill LayersIntroducingVulnerability. return
// And time it. }
return notification, nil // Fills Vulnerability.LayersIntroducingVulnerability.
// limit -1: won't do anything
// limit 0: will just get the startID of the second page
func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) {
tf := time.Now()
if vulnerability == nil {
return -1, nil
}
// A startID equals to -1 means that we reached the end already.
if startID == -1 || limit == -1 {
return -1, nil
}
// We do `defer observeQueryTime` here because we don't want to observe invalid calls.
defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf)
// Query with limit + 1, the last item will be used to know the next starting ID.
rows, err := pgSQL.Query(getQuery("s_notification_layer_introducing_vulnerability"),
vulnerability.ID, startID, limit+1)
if err != nil {
return 0, handleError("s_vulnerability_fixedin_feature", err)
}
defer rows.Close()
var layers []database.Layer
for rows.Next() {
var layer database.Layer
if err := rows.Scan(&layer.ID, &layer.Name); err != nil {
return -1, handleError("s_notification_layer_introducing_vulnerability.Scan()", err)
}
layers = append(layers, layer)
}
if err = rows.Err(); err != nil {
return -1, handleError("s_notification_layer_introducing_vulnerability.Rows()", err)
}
size := limit
if len(layers) < limit {
size = len(layers)
}
vulnerability.LayersIntroducingVulnerability = layers[:size]
nextID := -1
if len(layers) > limit {
nextID = layers[limit].ID
}
return nextID, nil
} }
func (pgSQL *pgSQL) SetNotificationNotified(name string) error { func (pgSQL *pgSQL) SetNotificationNotified(name string) error {

View File

@ -0,0 +1,123 @@
package pgsql
import (
"testing"
"time"
"fmt"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestNotification(t *testing.T) {
datastore, err := OpenForTest("Notification", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Try to get a notification when there is none.
_, err = datastore.GetAvailableNotification(time.Second)
assert.Equal(t, cerrors.ErrNotFound, err)
// Create some data.
f1 := database.Feature{
Name: "TestNotificationFeature1",
Namespace: database.Namespace{Name: "TestNotificationNamespace1"},
}
l1 := database.Layer{
Name: "TestNotificationLayer1",
Features: []database.FeatureVersion{
database.FeatureVersion{
Feature: f1,
Version: types.NewVersionUnsafe("0.1"),
},
},
}
l2 := database.Layer{
Name: "TestNotificationLayer2",
Features: []database.FeatureVersion{
database.FeatureVersion{
Feature: f1,
Version: types.NewVersionUnsafe("0.2"),
},
},
}
l3 := database.Layer{
Name: "TestNotificationLayer3",
Features: []database.FeatureVersion{
database.FeatureVersion{
Feature: f1,
Version: types.NewVersionUnsafe("0.3"),
},
},
}
if assert.Nil(t, datastore.InsertLayer(l1)) && assert.Nil(t, datastore.InsertLayer(l2)) &&
assert.Nil(t, datastore.InsertLayer(l3)) {
// Insert a new vulnerability that is introduced by three layers.
v1 := database.Vulnerability{
Name: "TestNotificationVulnerability1",
Namespace: f1.Namespace,
Description: "TestNotificationDescription1",
Link: "TestNotificationLink1",
Severity: "Unknown",
FixedIn: []database.FeatureVersion{
database.FeatureVersion{
Feature: f1,
Version: types.NewVersionUnsafe("1.0"),
},
},
}
assert.Nil(t, datastore.insertVulnerability(v1))
// Get the notification associated to the previously inserted vulnerability.
notification, err := datastore.GetAvailableNotification(time.Second)
assert.Nil(t, err)
assert.NotEmpty(t, notification.Name)
// Verify the renotify behaviour.
if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) {
_, err := datastore.GetAvailableNotification(time.Second)
assert.Equal(t, cerrors.ErrNotFound, err)
time.Sleep(50 * time.Millisecond)
notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond)
assert.Nil(t, err)
assert.Equal(t, notification.Name, notificationB.Name)
datastore.SetNotificationNotified(notification.Name)
}
// Get notification.
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
assert.Nil(t, err)
assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage)
assert.Nil(t, filledNotification.OldVulnerability)
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2)
// Get second page.
filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage)
assert.Nil(t, err)
assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage)
assert.Nil(t, filledNotification.OldVulnerability)
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
// Delete notification.
assert.Nil(t, datastore.DeleteNotification(notification.Name))
n, err := datastore.GetAvailableNotification(time.Millisecond)
assert.Equal(t, cerrors.ErrNotFound, err)
fmt.Println(n)
}
}

View File

@ -236,6 +236,10 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
// handleError logs an error with an extra description and masks the error if it's an SQL one. // handleError logs an error with an extra description and masks the error if it's an SQL one.
// This ensures we never return plain SQL errors and leak anything. // This ensures we never return plain SQL errors and leak anything.
func handleError(desc string, err error) error { func handleError(desc string, err error) error {
if err == nil {
return nil
}
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return cerrors.ErrNotFound return cerrors.ErrNotFound
} }

View File

@ -190,21 +190,43 @@ func init() {
INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability, new_vulnerability) INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability, new_vulnerability)
VALUES($1, CURRENT_TIMESTAMP, $2, $3)` VALUES($1, CURRENT_TIMESTAMP, $2, $3)`
queries["r_notification"] = `UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP` queries["u_notification_notified"] = `
UPDATE Vulnerability_Notification
SET notified_at = CURRENT_TIMESTAMP
WHERE name = $1`
queries["r_notification"] = `
UPDATE Vulnerability_Notification
SET deleted_at = CURRENT_TIMESTAMP
WHERE name = $1`
queries["s_notification_available"] = ` queries["s_notification_available"] = `
SELECT name, created_at, notified_at, deleted_at SELECT id, name, created_at, notified_at, deleted_at
FROM Vulnerability_Notification FROM Vulnerability_Notification
WHERE notified_at = NULL OR notified_at < $1 WHERE (notified_at IS NULL OR notified_at < $1)
AND deleted_at IS NULL
AND name NOT IN (SELECT name FROM Lock) AND name NOT IN (SELECT name FROM Lock)
ORDER BY Random() ORDER BY Random()
LIMIT 1` LIMIT 1`
queries["s_notification"] = ` queries["s_notification"] = `
SELECT name, created_at, notified_at, deleted_at, old_vulnerability, new_vulnerability SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability, new_vulnerability
FROM Vulnerability_Notification FROM Vulnerability_Notification
WHERE name = $1` WHERE name = $1`
queries["s_notification_layer_introducing_vulnerability"] = `
SELECT l.ID, l.name
FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l
WHERE v.id = $1
AND v.id = vafv.vulnerability_id
AND vafv.featureversion_id = fv.id
AND fv.id = ldfv.featureversion_id
AND ldfv.modification = 'add'
AND ldfv.layer_id = l.id
AND l.id >= $2
ORDER BY l.ID
LIMIT $3`
// complex_test.go // complex_test.go
queries["s_complextest_featureversion_affects"] = ` queries["s_complextest_featureversion_affects"] = `
SELECT v.name SELECT v.name

View File

@ -223,17 +223,13 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er
} }
// Create notification. // Create notification.
var notification database.VulnerabilityNotification notification := database.VulnerabilityNotification{
if existingVulnerability.ID == 0 {
notification = database.VulnerabilityNotification{
NewVulnerability: vulnerability,
}
} else {
notification = database.VulnerabilityNotification{
OldVulnerability: &existingVulnerability,
NewVulnerability: vulnerability, NewVulnerability: vulnerability,
} }
if existingVulnerability.ID != 0 {
notification.OldVulnerability = &existingVulnerability
} }
if err := pgSQL.insertNotification(tx, notification); err != nil { if err := pgSQL.insertNotification(tx, notification); err != nil {
return err return err
} }

View File

@ -179,7 +179,7 @@ func (v Version) MarshalJSON() ([]byte, error) {
func (v *Version) UnmarshalJSON(b []byte) (err error) { func (v *Version) UnmarshalJSON(b []byte) (err error) {
var str string var str string
json.Unmarshal(b, &str) json.Unmarshal(b, &str)
vp, err := NewVersion(str) vp := NewVersionUnsafe(str)
*v = vp *v = vp
return return
} }