database: log and mask SQL errors
This commit is contained in:
parent
970756cd5a
commit
6a9cf21fd4
@ -4,13 +4,16 @@ import "errors"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrTransaction is an error that occurs when a database transaction fails.
|
// ErrTransaction is an error that occurs when a database transaction fails.
|
||||||
ErrTransaction = errors.New("database: transaction failed (concurrent modification?)")
|
// ErrTransaction = errors.New("database: transaction failed (concurrent modification?)")
|
||||||
|
|
||||||
// ErrBackendException is an error that occurs when the database backend does
|
// ErrBackendException is an error that occurs when the database backend does
|
||||||
// not work properly (ie. unreachable).
|
// not work properly (ie. unreachable).
|
||||||
ErrBackendException = errors.New("database: could not query backend")
|
ErrBackendException = errors.New("database: an error occured when querying the backend")
|
||||||
|
|
||||||
// ErrInconsistent is an error that occurs when a database consistency check
|
// ErrInconsistent is an error that occurs when a database consistency check
|
||||||
// fails (ie. when an entity which is supposed to be unique is detected twice)
|
// fails (ie. when an entity which is supposed to be unique is detected twice)
|
||||||
ErrInconsistent = errors.New("database: inconsistent database")
|
ErrInconsistent = errors.New("database: inconsistent database")
|
||||||
|
|
||||||
// ErrCantOpen is an error that occurs when the database could not be opened
|
// ErrCantOpen is an error that occurs when the database could not be opened
|
||||||
ErrCantOpen = errors.New("database: could not open database")
|
ErrCantOpen = errors.New("database: could not open database")
|
||||||
)
|
)
|
||||||
@ -26,8 +29,8 @@ type Datastore interface {
|
|||||||
// DeleteVulnerability(id string)
|
// DeleteVulnerability(id string)
|
||||||
|
|
||||||
// Notifications
|
// Notifications
|
||||||
// InsertNotifications([]*Notification) error
|
// InsertNotifications([]Notification) error
|
||||||
// FindNotificationToSend() (*Notification, error)
|
// FindNotificationToSend() (Notification, error)
|
||||||
// CountNotificationsToSend() (int, error)
|
// CountNotificationsToSend() (int, error)
|
||||||
// MarkNotificationAsSent(id string)
|
// MarkNotificationAsSent(id string)
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) {
|
|||||||
var id int
|
var id int
|
||||||
err = pgSQL.QueryRow(getQuery("soi_feature"), feature.Name, namespaceID).Scan(&id)
|
err = pgSQL.QueryRow(getQuery("soi_feature"), feature.Name, namespaceID).Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, handleError("soi_feature", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pgSQL.cache != nil {
|
if pgSQL.cache != nil {
|
||||||
@ -59,7 +59,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
|
|||||||
tx, err := pgSQL.Begin()
|
tx, err := pgSQL.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("insertFeatureVersion.Begin()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find or create FeatureVersion.
|
// Find or create FeatureVersion.
|
||||||
@ -68,7 +68,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
|
|||||||
Scan(&newOrExisting, &featureVersion.ID)
|
Scan(&newOrExisting, &featureVersion.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("soi_featureversion", err)
|
||||||
}
|
}
|
||||||
if newOrExisting == "exi" {
|
if newOrExisting == "exi" {
|
||||||
// That featureVersion already exists, return its id.
|
// That featureVersion already exists, return its id.
|
||||||
@ -83,14 +83,14 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
|
|||||||
_, err = tx.Exec(getQuery("l_share_vulnerability_fixedin_feature"))
|
_, err = tx.Exec(getQuery("l_share_vulnerability_fixedin_feature"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("l_share_vulnerability_fixedin_feature", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select every vulnerability and the fixed version that affect this Feature.
|
// Select every vulnerability and the fixed version that affect this Feature.
|
||||||
rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureID)
|
rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("s_vulnerability_fixedin_feature", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
|
|||||||
err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion)
|
err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("s_vulnerability_fixedin_feature.Scan()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if featureVersion.Version.Compare(fixedInVersion) < 0 {
|
if featureVersion.Version.Compare(fixedInVersion) < 0 {
|
||||||
@ -111,16 +111,19 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
|
|||||||
featureVersion.ID, fixedInID)
|
featureVersion.ID, fixedInID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("i_vulnerability_affects_featureversion", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return 0, handleError("s_vulnerability_fixedin_feature.Rows()", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Commit transaction.
|
// Commit transaction.
|
||||||
err = tx.Commit()
|
err = tx.Commit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return 0, err
|
return 0, handleError("insertFeatureVersion.Commit()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pgSQL.cache != nil {
|
if pgSQL.cache != nil {
|
||||||
|
@ -26,7 +26,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
|
|||||||
// First, try to update.
|
// First, try to update.
|
||||||
r, err := pgSQL.Exec(getQuery("u_keyvalue"), value, key)
|
r, err := pgSQL.Exec(getQuery("u_keyvalue"), value, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return handleError("u_keyvalue", err)
|
||||||
}
|
}
|
||||||
if n, _ := r.RowsAffected(); n > 0 {
|
if n, _ := r.RowsAffected(); n > 0 {
|
||||||
// Updated successfully.
|
// Updated successfully.
|
||||||
@ -41,7 +41,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
|
|||||||
// Got unique constraint violation, retry.
|
// Got unique constraint violation, retry.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return handleError("i_keyvalue", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -49,10 +49,16 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist.
|
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist.
|
||||||
func (pgSQL *pgSQL) GetKeyValue(key string) (value string, err error) {
|
func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) {
|
||||||
err = pgSQL.QueryRow(getQuery("s_keyvalue"), key).Scan(&value)
|
var value string
|
||||||
|
err := pgSQL.QueryRow(getQuery("s_keyvalue"), key).Scan(&value)
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return
|
if err != nil {
|
||||||
|
return "", handleError("s_keyvalue", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
|
|||||||
return layer, cerrors.ErrNotFound
|
return layer, cerrors.ErrNotFound
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return layer, err
|
return layer, handleError("s_layer", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !parentID.IsZero() {
|
if !parentID.IsZero() {
|
||||||
@ -78,7 +78,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas
|
|||||||
// Query
|
// Query
|
||||||
rows, err := pgSQL.Query(query, layerID)
|
rows, err := pgSQL.Query(query, layerID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return featureVersions, err
|
return featureVersions, handleError(query, err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@ -91,14 +91,14 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas
|
|||||||
if idOnly {
|
if idOnly {
|
||||||
err = rows.Scan(&featureVersion.ID, &modification)
|
err = rows.Scan(&featureVersion.ID, &modification)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return featureVersions, err
|
return featureVersions, handleError(query+".Scan()", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = rows.Scan(&featureVersion.ID, &modification, &featureVersion.Feature.Namespace.ID,
|
err = rows.Scan(&featureVersion.ID, &modification, &featureVersion.Feature.Namespace.ID,
|
||||||
&featureVersion.Feature.Namespace.Name, &featureVersion.Feature.ID,
|
&featureVersion.Feature.Namespace.Name, &featureVersion.Feature.ID,
|
||||||
&featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version)
|
&featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return featureVersions, err
|
return featureVersions, handleError(query+".Scan()", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return featureVersions, err
|
return featureVersions, handleError(query+".Rows()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build result by converting our map to a slice
|
// Build result by converting our map to a slice
|
||||||
@ -141,7 +141,7 @@ func (pgSQL *pgSQL) loadAffectedBy(featureVersions []database.FeatureVersion) er
|
|||||||
rows, err := pgSQL.Query(getQuery("s_featureversions_vulnerabilities"),
|
rows, err := pgSQL.Query(getQuery("s_featureversions_vulnerabilities"),
|
||||||
buildInputArray(featureVersionIDs))
|
buildInputArray(featureVersionIDs))
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return handleError("s_featureversions_vulnerabilities", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@ -153,12 +153,12 @@ func (pgSQL *pgSQL) loadAffectedBy(featureVersions []database.FeatureVersion) er
|
|||||||
&vulnerability.Description, &vulnerability.Link, &vulnerability.Severity,
|
&vulnerability.Description, &vulnerability.Link, &vulnerability.Severity,
|
||||||
&vulnerability.Namespace.Name, &vulnerability.FixedBy)
|
&vulnerability.Namespace.Name, &vulnerability.FixedBy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return handleError("s_featureversions_vulnerabilities.Scan()", err)
|
||||||
}
|
}
|
||||||
vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability)
|
vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability)
|
||||||
}
|
}
|
||||||
if err = rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return err
|
return handleError("s_featureversions_vulnerabilities.Rows()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assign vulnerabilities to every FeatureVersions
|
// Assign vulnerabilities to every FeatureVersions
|
||||||
@ -208,7 +208,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
|
|||||||
tx, err := pgSQL.Begin()
|
tx, err := pgSQL.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return handleError("InsertLayer.Begin()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find or insert namespace if provided.
|
// Find or insert namespace if provided.
|
||||||
@ -243,7 +243,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
|
|||||||
Scan(&layer.ID)
|
Scan(&layer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return handleError("i_layer", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if existingLayer.EngineVersion >= layer.EngineVersion {
|
if existingLayer.EngineVersion >= layer.EngineVersion {
|
||||||
@ -255,14 +255,14 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
|
|||||||
_, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID)
|
_, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return handleError("u_layer", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove all existing Layer_diff_FeatureVersion.
|
// Remove all existing Layer_diff_FeatureVersion.
|
||||||
_, err = tx.Exec(getQuery("r_layer_diff_featureversion"), layer.ID)
|
_, err = tx.Exec(getQuery("r_layer_diff_featureversion"), layer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return handleError("r_layer_diff_featureversion", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -276,7 +276,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
|
|||||||
err = tx.Commit()
|
err = tx.Commit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return err
|
return handleError("InsertLayer.Commit()", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -324,13 +324,13 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *
|
|||||||
if len(addIDs) > 0 {
|
if len(addIDs) > 0 {
|
||||||
_, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "add", buildInputArray(addIDs))
|
_, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "add", buildInputArray(addIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return handleError("i_layer_diff_featureversion.Add", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(delIDs) > 0 {
|
if len(delIDs) > 0 {
|
||||||
_, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "del", buildInputArray(delIDs))
|
_, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "del", buildInputArray(delIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return handleError("i_layer_diff_featureversion.Del", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
cerrors "github.com/coreos/clair/utils/errors"
|
cerrors "github.com/coreos/clair/utils/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (id int, err error) {
|
func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) {
|
||||||
if namespace.Name == "" {
|
if namespace.Name == "" {
|
||||||
return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace")
|
return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace")
|
||||||
}
|
}
|
||||||
@ -16,11 +16,15 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (id int, err e
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pgSQL.QueryRow(getQuery("soi_namespace"), namespace.Name).Scan(&id)
|
var id int
|
||||||
|
err := pgSQL.QueryRow(getQuery("soi_namespace"), namespace.Name).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, handleError("soi_namespace", err)
|
||||||
|
}
|
||||||
|
|
||||||
if pgSQL.cache != nil {
|
if pgSQL.cache != nil {
|
||||||
pgSQL.cache.Add("namespace:"+namespace.Name, id)
|
pgSQL.cache.Add("namespace:"+namespace.Name, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return id, nil
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"bitbucket.org/liamstask/goose/lib/goose"
|
"bitbucket.org/liamstask/goose/lib/goose"
|
||||||
"github.com/coreos/clair/config"
|
"github.com/coreos/clair/config"
|
||||||
"github.com/coreos/clair/database"
|
"github.com/coreos/clair/database"
|
||||||
|
cerrors "github.com/coreos/clair/utils/errors"
|
||||||
"github.com/coreos/pkg/capnslog"
|
"github.com/coreos/pkg/capnslog"
|
||||||
"github.com/hashicorp/golang-lru"
|
"github.com/hashicorp/golang-lru"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
@ -33,14 +34,16 @@ func (pgSQL *pgSQL) Close() {
|
|||||||
// It will run immediately every necessary migration on the database.
|
// It will run immediately every necessary migration on the database.
|
||||||
func Open(config *config.DatabaseConfig) (database.Datastore, error) {
|
func Open(config *config.DatabaseConfig) (database.Datastore, error) {
|
||||||
// Run migrations.
|
// Run migrations.
|
||||||
if err := Migrate(config.Source); err != nil {
|
if err := migrate(config.Source); err != nil {
|
||||||
return nil, fmt.Errorf("could not run database migration: %v", err)
|
log.Error(err)
|
||||||
|
return nil, database.ErrCantOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open database.
|
// Open database.
|
||||||
db, err := sql.Open("postgres", config.Source)
|
db, err := sql.Open("postgres", config.Source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not open database (Open): %v", err)
|
log.Error(err)
|
||||||
|
return nil, database.ErrCantOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize cache.
|
// Initialize cache.
|
||||||
@ -53,8 +56,8 @@ func Open(config *config.DatabaseConfig) (database.Datastore, error) {
|
|||||||
return &pgSQL{DB: db, cache: cache}, nil
|
return &pgSQL{DB: db, cache: cache}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate runs all available migrations on a pgSQL database.
|
// migrate runs all available migrations on a pgSQL database.
|
||||||
func Migrate(dataSource string) error {
|
func migrate(dataSource string) error {
|
||||||
log.Info("running database migrations")
|
log.Info("running database migrations")
|
||||||
|
|
||||||
_, filename, _, _ := runtime.Caller(1)
|
_, filename, _, _ := runtime.Caller(1)
|
||||||
@ -85,9 +88,9 @@ func Migrate(dataSource string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates a new database.
|
// createDatabase creates a new database.
|
||||||
// The dataSource parameter should not contain a dbname.
|
// The dataSource parameter should not contain a dbname.
|
||||||
func CreateDatabase(dataSource, databaseName string) error {
|
func createDatabase(dataSource, databaseName string) error {
|
||||||
// Open database.
|
// Open database.
|
||||||
db, err := sql.Open("postgres", dataSource)
|
db, err := sql.Open("postgres", dataSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -104,9 +107,9 @@ func CreateDatabase(dataSource, databaseName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops an existing database.
|
// dropDatabase drops an existing database.
|
||||||
// The dataSource parameter should not contain a dbname.
|
// The dataSource parameter should not contain a dbname.
|
||||||
func DropDatabase(dataSource, databaseName string) error {
|
func dropDatabase(dataSource, databaseName string) error {
|
||||||
// Open database.
|
// Open database.
|
||||||
db, err := sql.Open("postgres", dataSource)
|
db, err := sql.Open("postgres", dataSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,7 +136,7 @@ type pgSQLTest struct {
|
|||||||
|
|
||||||
func (pgSQL *pgSQLTest) Close() {
|
func (pgSQL *pgSQLTest) Close() {
|
||||||
pgSQL.DB.Close()
|
pgSQL.DB.Close()
|
||||||
DropDatabase(pgSQL.dataSource+"dbname=postgres", pgSQL.dbName)
|
dropDatabase(pgSQL.dataSource+"dbname=postgres", pgSQL.dbName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
|
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
|
||||||
@ -144,16 +147,18 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
|
|||||||
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
||||||
|
|
||||||
// Create database.
|
// Create database.
|
||||||
err := CreateDatabase(dataSource+"dbname=postgres", dbName)
|
err := createDatabase(dataSource+"dbname=postgres", dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Error(err)
|
||||||
|
return nil, database.ErrCantOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open database.
|
// Open database.
|
||||||
db, err := Open(&config.DatabaseConfig{Source: dataSource + "dbname=" + dbName, CacheSize: 0})
|
db, err := Open(&config.DatabaseConfig{Source: dataSource + "dbname=" + dbName, CacheSize: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
DropDatabase(dataSource, dbName)
|
dropDatabase(dataSource, dbName)
|
||||||
return nil, err
|
log.Error(err)
|
||||||
|
return nil, database.ErrCantOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load test data if specified.
|
// Load test data if specified.
|
||||||
@ -162,14 +167,31 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
|
|||||||
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
|
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
|
||||||
_, err = db.(*pgSQL).Exec(string(d))
|
_, err = db.(*pgSQL).Exec(string(d))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
DropDatabase(dataSource, dbName)
|
dropDatabase(dataSource, dbName)
|
||||||
return nil, err
|
log.Error(err)
|
||||||
|
return nil, database.ErrCantOpen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pgSQLTest{pgSQL: db.(*pgSQL), dataSource: dataSource, dbName: dbName}, nil
|
return &pgSQLTest{pgSQL: db.(*pgSQL), dataSource: dataSource, dbName: dbName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||||
|
// This ensures we never return plain SQL errors and leak anything.
|
||||||
|
func handleError(desc string, err error) error {
|
||||||
|
log.Errorf("%s: %v", desc, err)
|
||||||
|
|
||||||
|
if _, ok := err.(*pq.Error); ok {
|
||||||
|
return database.ErrBackendException
|
||||||
|
} else if err == sql.ErrNoRows {
|
||||||
|
return cerrors.ErrNotFound
|
||||||
|
} else if err == sql.ErrTxDone {
|
||||||
|
return database.ErrBackendException
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// isErrUniqueViolation determines is the given error is a unique contraint violation.
|
// isErrUniqueViolation determines is the given error is a unique contraint violation.
|
||||||
func isErrUniqueViolation(err error) bool {
|
func isErrUniqueViolation(err error) bool {
|
||||||
pqErr, ok := err.(*pq.Error)
|
pqErr, ok := err.(*pq.Error)
|
||||||
|
Loading…
Reference in New Issue
Block a user