From 0b32b36cf7168eef2c005a3d7ec9c3a5996d910b Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Wed, 6 Mar 2019 16:29:00 -0500 Subject: [PATCH] pgsql: Move detector to pgsql/detector module --- database/pgsql/detector.go | 196 ------------------ database/pgsql/detector/detector.go | 132 ++++++++++++ .../pgsql/{ => detector}/detector_test.go | 12 +- 3 files changed, 139 insertions(+), 201 deletions(-) delete mode 100644 database/pgsql/detector.go create mode 100644 database/pgsql/detector/detector.go rename database/pgsql/{ => detector}/detector_test.go (91%) diff --git a/database/pgsql/detector.go b/database/pgsql/detector.go deleted file mode 100644 index 3209d632..00000000 --- a/database/pgsql/detector.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - - "github.com/deckarep/golang-set" - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair/database" -) - -const ( - soiDetector = ` - INSERT INTO detector (name, version, dtype) - SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type ) - WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);` - - selectAncestryDetectors = ` - SELECT d.name, d.version, d.dtype - FROM ancestry_detector, detector AS d - WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;` - - selectLayerDetectors = ` - SELECT d.name, d.version, d.dtype - FROM layer_detector, detector AS d - WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;` - - insertAncestryDetectors = ` - INSERT INTO ancestry_detector (ancestry_id, detector_id) - SELECT $1, $2 - WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)` - - persistLayerDetector = ` - INSERT INTO layer_detector (layer_id, detector_id) - SELECT $1, $2 - WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)` - - findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3` - findAllDetectors = `SELECT id, name, version, dtype FROM detector` -) - -type detectorMap struct { - byID map[int64]database.Detector - byValue map[database.Detector]int64 -} - -func (tx *pgSession) PersistDetectors(detectors []database.Detector) error { - for _, d := range detectors { - if !d.Valid() { - log.WithField("detector", d).Debug("Invalid Detector") - return database.ErrInvalidParameters - } - - r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType) - if err != nil { - return handleError("soiDetector", err) - } - - count, err := r.RowsAffected() - if err != nil { - return handleError("soiDetector", err) - } - - if count == 0 { - log.Debug("detector already exists: ", d) - } - } - - return nil -} - -func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error { - if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil { - return handleError("persistLayerDetector", err) - } - - return nil -} - -func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error { - alreadySaved := mapset.NewSet() - for _, id := range detectorIDs { - if alreadySaved.Contains(id) { - continue - } - - alreadySaved.Add(id) - if err := tx.persistLayerDetector(layerID, id); err != nil { - return err - } - } - - return nil -} - -func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error { - for _, detectorID := range detectorIDs { - if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil { - return handleError("insertAncestryDetectors", err) - } - } - - return nil -} - -func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) { - detectors, err := tx.getDetectors(selectAncestryDetectors, id) - return detectors, err -} - -func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) { - detectors, err := tx.getDetectors(selectLayerDetectors, id) - return detectors, err -} - -// findDetectorIDs retrieve ids of the detectors from the database, if any is not -// found, return the error. -func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) { - ids := []int64{} - for _, d := range detectors { - id := sql.NullInt64{} - err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id) - if err != nil { - return nil, handleError("findDetectorID", err) - } - - if !id.Valid { - return nil, database.ErrInconsistent - } - - ids = append(ids, id.Int64) - } - - return ids, nil -} - -func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) { - rows, err := tx.Query(query, id) - if err != nil { - return nil, handleError("getDetectors", err) - } - - detectors := []database.Detector{} - for rows.Next() { - d := database.Detector{} - err := rows.Scan(&d.Name, &d.Version, &d.DType) - if err != nil { - return nil, handleError("getDetectors", err) - } - - if !d.Valid() { - return nil, database.ErrInvalidDetector - } - - detectors = append(detectors, d) - } - - return detectors, nil -} - -func (tx *pgSession) findAllDetectors() (detectorMap, error) { - rows, err := tx.Query(findAllDetectors) - if err != nil { - return detectorMap{}, handleError("searchAllDetectors", err) - } - - detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)} - for rows.Next() { - var ( - id int64 - d database.Detector - ) - if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil { - return detectorMap{}, handleError("searchAllDetectors", err) - } - - detectors.byID[id] = d - detectors.byValue[d] = id - } - - return detectors, nil -} diff --git a/database/pgsql/detector/detector.go b/database/pgsql/detector/detector.go new file mode 100644 index 00000000..08643c90 --- /dev/null +++ b/database/pgsql/detector/detector.go @@ -0,0 +1,132 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package detector + +import ( + "database/sql" + + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" +) + +const ( + soiDetector = ` + INSERT INTO detector (name, version, dtype) + SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type ) + WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);` + + findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3` + findAllDetectors = `SELECT id, name, version, dtype FROM detector` +) + +type DetectorMap struct { + ByID map[int64]database.Detector + ByValue map[database.Detector]int64 +} + +func PersistDetectors(tx *sql.Tx, detectors []database.Detector) error { + for _, d := range detectors { + if !d.Valid() { + log.WithField("detector", d).Debug("Invalid Detector") + return database.ErrInvalidParameters + } + + r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType) + if err != nil { + return util.HandleError("soiDetector", err) + } + + count, err := r.RowsAffected() + if err != nil { + return util.HandleError("soiDetector", err) + } + + if count == 0 { + log.Debug("detector already exists: ", d) + } + } + + return nil +} + +// findDetectorIDs retrieve ids of the detectors from the database, if any is not +// found, return the error. +func FindDetectorIDs(tx *sql.Tx, detectors []database.Detector) ([]int64, error) { + ids := []int64{} + for _, d := range detectors { + id := sql.NullInt64{} + err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id) + if err != nil { + return nil, util.HandleError("findDetectorID", err) + } + + if !id.Valid { + return nil, database.ErrInconsistent + } + + ids = append(ids, id.Int64) + } + + return ids, nil +} + +func GetDetectors(tx *sql.Tx, query string, id int64) ([]database.Detector, error) { + rows, err := tx.Query(query, id) + if err != nil { + return nil, util.HandleError("getDetectors", err) + } + + detectors := []database.Detector{} + for rows.Next() { + d := database.Detector{} + err := rows.Scan(&d.Name, &d.Version, &d.DType) + if err != nil { + return nil, util.HandleError("getDetectors", err) + } + + if !d.Valid() { + return nil, database.ErrInvalidDetector + } + + detectors = append(detectors, d) + } + + return detectors, nil +} + +func FindAllDetectors(tx *sql.Tx) (DetectorMap, error) { + rows, err := tx.Query(findAllDetectors) + if err != nil { + return DetectorMap{}, util.HandleError("searchAllDetectors", err) + } + + detectors := DetectorMap{ByID: make(map[int64]database.Detector), ByValue: make(map[database.Detector]int64)} + for rows.Next() { + var ( + id int64 + d database.Detector + ) + if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil { + return DetectorMap{}, util.HandleError("searchAllDetectors", err) + } + + detectors.ByID[id] = d + detectors.ByValue[d] = id + } + + return detectors, nil +} diff --git a/database/pgsql/detector_test.go b/database/pgsql/detector/detector_test.go similarity index 91% rename from database/pgsql/detector_test.go rename to database/pgsql/detector/detector_test.go index 582da60b..27e3fad5 100644 --- a/database/pgsql/detector_test.go +++ b/database/pgsql/detector/detector_test.go @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package detector import ( + "database/sql" "testing" "github.com/deckarep/golang-set" "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) -func testGetAllDetectors(tx *pgSession) []database.Detector { +func testGetAllDetectors(tx *sql.Tx) []database.Detector { query := `SELECT name, version, dtype FROM detector` rows, err := tx.Query(query) if err != nil { @@ -90,12 +92,12 @@ var persistDetectorTests = []struct { } func TestPersistDetector(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistDetector", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistDetector") + defer cleanup() for _, test := range persistDetectorTests { t.Run(test.title, func(t *testing.T) { - err := tx.PersistDetectors(test.in) + err := PersistDetectors(tx, test.in) if test.err != "" { require.EqualError(t, err, test.err) return