// Copyright 2017 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package pgsql implements database.Datastore with PostgreSQL. package pgsql import ( "database/sql" "fmt" "io/ioutil" "net/url" "strings" "time" "os" "gopkg.in/yaml.v2" "github.com/hashicorp/golang-lru" "github.com/lib/pq" "github.com/prometheus/client_golang/prometheus" "github.com/remind101/migrate" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/migrations" "github.com/coreos/clair/database/pgsql/token" "github.com/coreos/clair/pkg/commonerr" ) var ( promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "clair_pgsql_errors_total", Help: "Number of errors that PostgreSQL requests generated.", }, []string{"request"}) promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "clair_pgsql_cache_hits_total", Help: "Number of cache hits that the PostgreSQL backend did.", }, []string{"object"}) promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "clair_pgsql_cache_queries_total", Help: "Number of cache queries that the PostgreSQL backend did.", }, []string{"object"}) promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "clair_pgsql_query_duration_milliseconds", Help: "Time it takes to execute the database query.", }, []string{"query", "subquery"}) promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "clair_pgsql_concurrent_lock_vafv_total", Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", }) ) func init() { prometheus.MustRegister(promErrorsTotal) prometheus.MustRegister(promCacheHitsTotal) prometheus.MustRegister(promCacheQueriesTotal) prometheus.MustRegister(promQueryDurationMilliseconds) prometheus.MustRegister(promConcurrentLockVAFV) database.Register("pgsql", openDatabase) } // pgSessionCache is the session's cache, which holds the pgSQL's cache and the // individual session's cache. Only when session.Commit is called, all the // changes to pgSQL cache will be applied. type pgSessionCache struct { c *lru.ARCCache } type pgSQL struct { *sql.DB cache *lru.ARCCache config Config } type pgSession struct { *sql.Tx paginationKey string } type idPageNumber struct { // StartID is an implementation detail for paginating by an ID required to // be unique to every ancestry and always increasing. // // StartID is used to search for ancestry with ID >= StartID StartID int64 } func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) { resultBytes, err := token.Marshal(page, paginationKey) if err != nil { return result, err } result = database.PageNumber(resultBytes) return result, nil } func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) { err = token.Unmarshal(string(page), paginationKey, &result) return } // Begin initiates a transaction to database. The expected transaction isolation // level in this implementation is "Read Committed". func (pgSQL *pgSQL) Begin() (database.Session, error) { tx, err := pgSQL.DB.Begin() if err != nil { return nil, err } return &pgSession{ Tx: tx, paginationKey: pgSQL.config.PaginationKey, }, nil } func (tx *pgSession) Commit() error { return tx.Tx.Commit() } // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // the configuration. func (pgSQL *pgSQL) Close() { if pgSQL.DB != nil { pgSQL.DB.Close() } if pgSQL.config.ManageDatabaseLifecycle { dbName, pgSourceURL, _ := parseConnectionString(pgSQL.config.Source) dropDatabase(pgSourceURL, dbName) } } // Ping verifies that the database is accessible. func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } // Config is the configuration that is used by openDatabase. type Config struct { Source string CacheSize int ManageDatabaseLifecycle bool FixturePath string PaginationKey string } // openDatabase opens a PostgresSQL-backed Datastore using the given // configuration. // // It immediately runs all necessary migrations. If ManageDatabaseLifecycle is // specified, the database will be created first. If FixturePath is specified, // every SQL queries that are present insides will be executed. func openDatabase(registrableComponentConfig database.RegistrableComponentConfig) (database.Datastore, error) { var pg pgSQL var err error // Parse configuration. pg.config = Config{ CacheSize: 16384, } bytes, err := yaml.Marshal(registrableComponentConfig.Options) if err != nil { return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) } err = yaml.Unmarshal(bytes, &pg.config) if err != nil { return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) } if pg.config.PaginationKey == "" { panic("pagination key should be given") } dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) if err != nil { return nil, err } // Create database. if pg.config.ManageDatabaseLifecycle { log.Info("pgsql: creating database") if err = createDatabase(pgSourceURL, dbName); err != nil { return nil, err } } // Open database. host := os.Getenv("POSTGRESQL_SERVICE_HOST") modifiedSource := strings.Replace(pg.config.Source, "$POSTGRESQL_SERVICE_HOST", host, -1) fmt.Println("postgresql hostname replaced: ", modifiedSource) pg.DB, err = sql.Open("postgres", modifiedSource) if err != nil { pg.Close() return nil, fmt.Errorf("pgsql: could not open database: %v", err) } // Verify database state. if err = pg.DB.Ping(); err != nil { pg.Close() return nil, fmt.Errorf("pgsql: could not open database: %v", err) } // Run migrations. if err = migrateDatabase(pg.DB); err != nil { pg.Close() return nil, err } // Load fixture data. if pg.config.FixturePath != "" { log.Info("pgsql: loading fixtures") d, err := ioutil.ReadFile(pg.config.FixturePath) if err != nil { pg.Close() return nil, fmt.Errorf("pgsql: could not open fixture file: %v", err) } _, err = pg.DB.Exec(string(d)) if err != nil { pg.Close() return nil, fmt.Errorf("pgsql: an error occurred while importing fixtures: %v", err) } } // Initialize cache. // TODO(Quentin-M): Benchmark with a simple LRU Cache. if pg.config.CacheSize > 0 { pg.cache, _ = lru.NewARC(pg.config.CacheSize) } return &pg, nil } func parseConnectionString(source string) (dbName string, pgSourceURL string, err error) { if source == "" { return "", "", commonerr.NewBadRequestError("pgsql: no database connection string specified") } sourceURL, err := url.Parse(source) if err != nil { return "", "", commonerr.NewBadRequestError("pgsql: database connection string is not a valid URL") } dbName = strings.TrimPrefix(sourceURL.Path, "/") pgSource := *sourceURL pgSource.Path = "/postgres" pgSourceURL = pgSource.String() return } // migrate runs all available migrations on a pgSQL database. func migrateDatabase(db *sql.DB) error { log.Info("running database migrations") err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) if err != nil { return fmt.Errorf("pgsql: an error occurred while running migrations: %v", err) } log.Info("database migration ran successfully") return nil } // createDatabase creates a new database. // The source parameter should not contain a dbname. func createDatabase(source, dbName string) error { // Open database. db, err := sql.Open("postgres", source) if err != nil { return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) } defer db.Close() // Create database. _, err = db.Exec("CREATE DATABASE " + dbName) if err != nil { return fmt.Errorf("pgsql: could not create database: %v", err) } return nil } // dropDatabase drops an existing database. // The source parameter should not contain a dbname. func dropDatabase(source, dbName string) error { // Open database. db, err := sql.Open("postgres", source) if err != nil { return fmt.Errorf("could not open database (DropDatabase): %v", err) } defer db.Close() // Kill any opened connection. if _, err = db.Exec(` SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid()`, dbName); err != nil { return fmt.Errorf("could not drop database: %v", err) } // Drop database. if _, err = db.Exec("DROP DATABASE " + dbName); err != nil { return fmt.Errorf("could not drop database: %v", err) } return nil } // handleError logs an error with an extra description and masks the error if it's an SQL one. // The function ensures we never return plain SQL errors and leak anything. // The function should be used for every database query error. func handleError(desc string, err error) error { if err == nil { return nil } if err == sql.ErrNoRows { return commonerr.ErrNotFound } log.WithError(err).WithField("Description", desc).Error("Handled Database Error") promErrorsTotal.WithLabelValues(desc).Inc() if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { return database.ErrBackendException } return err } // isErrUniqueViolation determines is the given error is a unique contraint violation. func isErrUniqueViolation(err error) bool { pqErr, ok := err.(*pq.Error) return ok && pqErr.Code == "23505" } // observeQueryTime computes the time elapsed since `start` to represent the // query time. // 1. `query` is a pgSession function name. // 2. `subquery` is a specific query or a batched query. // 3. `start` is the time right before query is executed. func observeQueryTime(query, subquery string, start time.Time) { promQueryDurationMilliseconds. WithLabelValues(query, subquery). Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond)) }