diff --git a/database/database.go b/database/database.go index ff311ab7..5f55f70a 100644 --- a/database/database.go +++ b/database/database.go @@ -4,13 +4,16 @@ import "errors" var ( // ErrTransaction is an error that occurs when a database transaction fails. - ErrTransaction = errors.New("database: transaction failed (concurrent modification?)") + // ErrTransaction = errors.New("database: transaction failed (concurrent modification?)") + // ErrBackendException is an error that occurs when the database backend does // not work properly (ie. unreachable). - ErrBackendException = errors.New("database: could not query backend") + ErrBackendException = errors.New("database: an error occured when querying the backend") + // ErrInconsistent is an error that occurs when a database consistency check // fails (ie. when an entity which is supposed to be unique is detected twice) ErrInconsistent = errors.New("database: inconsistent database") + // ErrCantOpen is an error that occurs when the database could not be opened ErrCantOpen = errors.New("database: could not open database") ) @@ -26,8 +29,8 @@ type Datastore interface { // DeleteVulnerability(id string) // Notifications - // InsertNotifications([]*Notification) error - // FindNotificationToSend() (*Notification, error) + // InsertNotifications([]Notification) error + // FindNotificationToSend() (Notification, error) // CountNotificationsToSend() (int, error) // MarkNotificationAsSent(id string) diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index 6d779533..33d81ab8 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -27,7 +27,7 @@ func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { var id int err = pgSQL.QueryRow(getQuery("soi_feature"), feature.Name, namespaceID).Scan(&id) if err != nil { - return 0, err + return 0, handleError("soi_feature", err) } if pgSQL.cache != nil { @@ -59,7 +59,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) tx, err := pgSQL.Begin() if err != nil { tx.Rollback() - return 0, err + return 0, handleError("insertFeatureVersion.Begin()", err) } // Find or create FeatureVersion. @@ -68,7 +68,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) Scan(&newOrExisting, &featureVersion.ID) if err != nil { tx.Rollback() - return 0, err + return 0, handleError("soi_featureversion", err) } if newOrExisting == "exi" { // That featureVersion already exists, return its id. @@ -83,14 +83,14 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) _, err = tx.Exec(getQuery("l_share_vulnerability_fixedin_feature")) if err != nil { tx.Rollback() - return 0, err + return 0, handleError("l_share_vulnerability_fixedin_feature", err) } // Select every vulnerability and the fixed version that affect this Feature. rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureID) if err != nil { tx.Rollback() - return 0, err + return 0, handleError("s_vulnerability_fixedin_feature", err) } defer rows.Close() @@ -100,7 +100,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion) if err != nil { tx.Rollback() - return 0, err + return 0, handleError("s_vulnerability_fixedin_feature.Scan()", err) } if featureVersion.Version.Compare(fixedInVersion) < 0 { @@ -111,16 +111,19 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) featureVersion.ID, fixedInID) if err != nil { tx.Rollback() - return 0, err + return 0, handleError("i_vulnerability_affects_featureversion", err) } } } + if err = rows.Err(); err != nil { + return 0, handleError("s_vulnerability_fixedin_feature.Rows()", err) + } // Commit transaction. err = tx.Commit() if err != nil { tx.Rollback() - return 0, err + return 0, handleError("insertFeatureVersion.Commit()", err) } if pgSQL.cache != nil { diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index bcdf9103..e0a3881f 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -26,7 +26,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { // First, try to update. r, err := pgSQL.Exec(getQuery("u_keyvalue"), value, key) if err != nil { - return err + return handleError("u_keyvalue", err) } if n, _ := r.RowsAffected(); n > 0 { // Updated successfully. @@ -41,7 +41,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { // Got unique constraint violation, retry. continue } - return err + return handleError("i_keyvalue", err) } return nil @@ -49,10 +49,16 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { } // GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. -func (pgSQL *pgSQL) GetKeyValue(key string) (value string, err error) { - err = pgSQL.QueryRow(getQuery("s_keyvalue"), key).Scan(&value) +func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { + var value string + err := pgSQL.QueryRow(getQuery("s_keyvalue"), key).Scan(&value) + if err == sql.ErrNoRows { return "", nil } - return + if err != nil { + return "", handleError("s_keyvalue", err) + } + + return value, nil } diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 361c8681..9b4e2537 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -25,7 +25,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo return layer, cerrors.ErrNotFound } if err != nil { - return layer, err + return layer, handleError("s_layer", err) } if !parentID.IsZero() { @@ -78,7 +78,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas // Query rows, err := pgSQL.Query(query, layerID) if err != nil && err != sql.ErrNoRows { - return featureVersions, err + return featureVersions, handleError(query, err) } defer rows.Close() @@ -91,14 +91,14 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas if idOnly { err = rows.Scan(&featureVersion.ID, &modification) if err != nil { - return featureVersions, err + return featureVersions, handleError(query+".Scan()", err) } } else { err = rows.Scan(&featureVersion.ID, &modification, &featureVersion.Feature.Namespace.ID, &featureVersion.Feature.Namespace.Name, &featureVersion.Feature.ID, &featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version) if err != nil { - return featureVersions, err + return featureVersions, handleError(query+".Scan()", err) } } @@ -114,7 +114,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas } } if err = rows.Err(); err != nil { - return featureVersions, err + return featureVersions, handleError(query+".Rows()", err) } // Build result by converting our map to a slice @@ -141,7 +141,7 @@ func (pgSQL *pgSQL) loadAffectedBy(featureVersions []database.FeatureVersion) er rows, err := pgSQL.Query(getQuery("s_featureversions_vulnerabilities"), buildInputArray(featureVersionIDs)) if err != nil && err != sql.ErrNoRows { - return err + return handleError("s_featureversions_vulnerabilities", err) } defer rows.Close() @@ -153,12 +153,12 @@ func (pgSQL *pgSQL) loadAffectedBy(featureVersions []database.FeatureVersion) er &vulnerability.Description, &vulnerability.Link, &vulnerability.Severity, &vulnerability.Namespace.Name, &vulnerability.FixedBy) if err != nil { - return err + return handleError("s_featureversions_vulnerabilities.Scan()", err) } vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) } if err = rows.Err(); err != nil { - return err + return handleError("s_featureversions_vulnerabilities.Rows()", err) } // Assign vulnerabilities to every FeatureVersions @@ -208,7 +208,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { tx, err := pgSQL.Begin() if err != nil { tx.Rollback() - return err + return handleError("InsertLayer.Begin()", err) } // Find or insert namespace if provided. @@ -243,7 +243,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { Scan(&layer.ID) if err != nil { tx.Rollback() - return err + return handleError("i_layer", err) } } else { if existingLayer.EngineVersion >= layer.EngineVersion { @@ -255,14 +255,14 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { _, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID) if err != nil { tx.Rollback() - return err + return handleError("u_layer", err) } // Remove all existing Layer_diff_FeatureVersion. _, err = tx.Exec(getQuery("r_layer_diff_featureversion"), layer.ID) if err != nil { tx.Rollback() - return err + return handleError("r_layer_diff_featureversion", err) } } @@ -276,7 +276,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { err = tx.Commit() if err != nil { tx.Rollback() - return err + return handleError("InsertLayer.Commit()", err) } return nil @@ -324,13 +324,13 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer * if len(addIDs) > 0 { _, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "add", buildInputArray(addIDs)) if err != nil { - return err + return handleError("i_layer_diff_featureversion.Add", err) } } if len(delIDs) > 0 { _, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "del", buildInputArray(delIDs)) if err != nil { - return err + return handleError("i_layer_diff_featureversion.Del", err) } } diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index bbb25e2e..fe97b980 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -5,7 +5,7 @@ import ( cerrors "github.com/coreos/clair/utils/errors" ) -func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (id int, err error) { +func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { if namespace.Name == "" { return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace") } @@ -16,11 +16,15 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (id int, err e } } - err = pgSQL.QueryRow(getQuery("soi_namespace"), namespace.Name).Scan(&id) + var id int + err := pgSQL.QueryRow(getQuery("soi_namespace"), namespace.Name).Scan(&id) + if err != nil { + return 0, handleError("soi_namespace", err) + } if pgSQL.cache != nil { pgSQL.cache.Add("namespace:"+namespace.Name, id) } - return + return id, nil } diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index de98b3a9..476eca0f 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -11,6 +11,7 @@ import ( "bitbucket.org/liamstask/goose/lib/goose" "github.com/coreos/clair/config" "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/pkg/capnslog" "github.com/hashicorp/golang-lru" "github.com/lib/pq" @@ -33,14 +34,16 @@ func (pgSQL *pgSQL) Close() { // It will run immediately every necessary migration on the database. func Open(config *config.DatabaseConfig) (database.Datastore, error) { // Run migrations. - if err := Migrate(config.Source); err != nil { - return nil, fmt.Errorf("could not run database migration: %v", err) + if err := migrate(config.Source); err != nil { + log.Error(err) + return nil, database.ErrCantOpen } // Open database. db, err := sql.Open("postgres", config.Source) if err != nil { - return nil, fmt.Errorf("could not open database (Open): %v", err) + log.Error(err) + return nil, database.ErrCantOpen } // Initialize cache. @@ -53,8 +56,8 @@ func Open(config *config.DatabaseConfig) (database.Datastore, error) { return &pgSQL{DB: db, cache: cache}, nil } -// Migrate runs all available migrations on a pgSQL database. -func Migrate(dataSource string) error { +// migrate runs all available migrations on a pgSQL database. +func migrate(dataSource string) error { log.Info("running database migrations") _, filename, _, _ := runtime.Caller(1) @@ -85,9 +88,9 @@ func Migrate(dataSource string) error { return nil } -// CreateDatabase creates a new database. +// createDatabase creates a new database. // The dataSource parameter should not contain a dbname. -func CreateDatabase(dataSource, databaseName string) error { +func createDatabase(dataSource, databaseName string) error { // Open database. db, err := sql.Open("postgres", dataSource) if err != nil { @@ -104,9 +107,9 @@ func CreateDatabase(dataSource, databaseName string) error { return nil } -// DropDatabase drops an existing database. +// dropDatabase drops an existing database. // The dataSource parameter should not contain a dbname. -func DropDatabase(dataSource, databaseName string) error { +func dropDatabase(dataSource, databaseName string) error { // Open database. db, err := sql.Open("postgres", dataSource) if err != nil { @@ -133,7 +136,7 @@ type pgSQLTest struct { func (pgSQL *pgSQLTest) Close() { pgSQL.DB.Close() - DropDatabase(pgSQL.dataSource+"dbname=postgres", pgSQL.dbName) + dropDatabase(pgSQL.dataSource+"dbname=postgres", pgSQL.dbName) } // OpenForTest creates a test Datastore backed by a new PostgreSQL database. @@ -144,16 +147,18 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1) // Create database. - err := CreateDatabase(dataSource+"dbname=postgres", dbName) + err := createDatabase(dataSource+"dbname=postgres", dbName) if err != nil { - return nil, err + log.Error(err) + return nil, database.ErrCantOpen } // Open database. db, err := Open(&config.DatabaseConfig{Source: dataSource + "dbname=" + dbName, CacheSize: 0}) if err != nil { - DropDatabase(dataSource, dbName) - return nil, err + dropDatabase(dataSource, dbName) + log.Error(err) + return nil, database.ErrCantOpen } // Load test data if specified. @@ -162,14 +167,31 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql") _, err = db.(*pgSQL).Exec(string(d)) if err != nil { - DropDatabase(dataSource, dbName) - return nil, err + dropDatabase(dataSource, dbName) + log.Error(err) + return nil, database.ErrCantOpen } } return &pgSQLTest{pgSQL: db.(*pgSQL), dataSource: dataSource, dbName: dbName}, nil } +// handleError logs an error with an extra description and masks the error if it's an SQL one. +// This ensures we never return plain SQL errors and leak anything. +func handleError(desc string, err error) error { + log.Errorf("%s: %v", desc, err) + + if _, ok := err.(*pq.Error); ok { + return database.ErrBackendException + } else if err == sql.ErrNoRows { + return cerrors.ErrNotFound + } else if err == sql.ErrTxDone { + return database.ErrBackendException + } + + return err +} + // isErrUniqueViolation determines is the given error is a unique contraint violation. func isErrUniqueViolation(err error) bool { pqErr, ok := err.(*pq.Error)