diff --git a/database/ancestry.go b/database/ancestry.go new file mode 100644 index 00000000..6ba8336c --- /dev/null +++ b/database/ancestry.go @@ -0,0 +1,96 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +// Ancestry is a manifest that keeps all layers in an image in order. +type Ancestry struct { + // Name is a globally unique value for a set of layers. This is often the + // sha256 digest of an OCI/Docker manifest. + Name string `json:"name"` + // By contains the processors that are used when computing the + // content of this ancestry. + By []Detector `json:"by"` + // Layers should be ordered and i_th layer is the parent of i+1_th layer in + // the slice. + Layers []AncestryLayer `json:"layers"` +} + +// Valid checks if the ancestry is compliant to spec. +func (a *Ancestry) Valid() bool { + if a == nil { + return false + } + + if a.Name == "" { + return false + } + + for _, d := range a.By { + if !d.Valid() { + return false + } + } + + for _, l := range a.Layers { + if !l.Valid() { + return false + } + } + + return true +} + +// AncestryLayer is a layer with all detected namespaced features. +type AncestryLayer struct { + // Hash is the sha-256 tarsum on the layer's blob content. + Hash string `json:"hash"` + // Features are the features introduced by this layer when it was + // processed. + Features []AncestryFeature `json:"features"` +} + +// Valid checks if the Ancestry Layer is compliant to the spec. +func (l *AncestryLayer) Valid() bool { + if l == nil { + return false + } + + if l.Hash == "" { + return false + } + + return true +} + +// GetFeatures returns the Ancestry's features. +func (l *AncestryLayer) GetFeatures() []NamespacedFeature { + nsf := make([]NamespacedFeature, 0, len(l.Features)) + for _, f := range l.Features { + nsf = append(nsf, f.NamespacedFeature) + } + + return nsf +} + +// AncestryFeature is a namespaced feature with the detectors used to +// find this feature. +type AncestryFeature struct { + NamespacedFeature `json:"namespacedFeature"` + + // FeatureBy is the detector that detected the feature. + FeatureBy Detector `json:"featureBy"` + // NamespaceBy is the detector that detected the namespace. + NamespaceBy Detector `json:"namespaceBy"` +} diff --git a/database/feature.go b/database/feature.go new file mode 100644 index 00000000..ac745b29 --- /dev/null +++ b/database/feature.go @@ -0,0 +1,96 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +// Feature represents a package detected in a layer but the namespace is not +// determined. +// +// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary +// dpkg is the version format of the installer package manager, which in this +// case could be dpkg or apk. +type Feature struct { + Name string `json:"name"` + Version string `json:"version"` + VersionFormat string `json:"versionFormat"` + Type FeatureType `json:"type"` +} + +// NamespacedFeature is a feature with determined namespace and can be affected +// by vulnerabilities. +// +// e.g. OpenSSL 1.0 dpkg Debian:7. +type NamespacedFeature struct { + Feature `json:"feature"` + + Namespace Namespace `json:"namespace"` +} + +// AffectedNamespacedFeature is a namespaced feature affected by the +// vulnerabilities with fixed-in versions for this feature. +type AffectedNamespacedFeature struct { + NamespacedFeature + + AffectedBy []VulnerabilityWithFixedIn +} + +// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve +// the affecting vulnerabilities and the fixed-in versions for the feature. +type VulnerabilityWithFixedIn struct { + Vulnerability + + FixedInVersion string +} + +// AffectedFeature is used to determine whether a namespaced feature is affected +// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is +// bound to vulnerability. +type AffectedFeature struct { + // FeatureType determines which type of package it affects. + FeatureType FeatureType + Namespace Namespace + FeatureName string + // FixedInVersion is known next feature version that's not affected by the + // vulnerability. Empty FixedInVersion means the unaffected version is + // unknown. + FixedInVersion string + // AffectedVersion contains the version range to determine whether or not a + // feature is affected. + AffectedVersion string +} + +// NullableAffectedNamespacedFeature is an affectednamespacedfeature with +// whether it's found in datastore. +type NullableAffectedNamespacedFeature struct { + AffectedNamespacedFeature + + Valid bool +} + +func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature { + return &Feature{name, version, versionFormat, featureType} +} + +func NewBinaryPackage(name string, version string, versionFormat string) *Feature { + return &Feature{name, version, versionFormat, BinaryPackage} +} + +func NewSourcePackage(name string, version string, versionFormat string) *Feature { + return &Feature{name, version, versionFormat, SourcePackage} +} + +func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature { + // TODO: namespaced feature should use pointer values + return &NamespacedFeature{*feature, *namespace} +} diff --git a/database/layer.go b/database/layer.go new file mode 100644 index 00000000..9ca3c410 --- /dev/null +++ b/database/layer.go @@ -0,0 +1,65 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +// Layer is a layer with all the detected features and namespaces. +type Layer struct { + // Hash is the sha-256 tarsum on the layer's blob content. + Hash string `json:"hash"` + // By contains a list of detectors scanned this Layer. + By []Detector `json:"by"` + Namespaces []LayerNamespace `json:"namespaces"` + Features []LayerFeature `json:"features"` +} + +func (l *Layer) GetFeatures() []Feature { + features := make([]Feature, 0, len(l.Features)) + for _, f := range l.Features { + features = append(features, f.Feature) + } + + return features +} + +func (l *Layer) GetNamespaces() []Namespace { + namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features)) + for _, ns := range l.Namespaces { + namespaces = append(namespaces, ns.Namespace) + } + for _, f := range l.Features { + if f.PotentialNamespace.Valid() { + namespaces = append(namespaces, f.PotentialNamespace) + } + } + + return namespaces +} + +// LayerNamespace is a namespace with detection information. +type LayerNamespace struct { + Namespace `json:"namespace"` + + // By is the detector found the namespace. + By Detector `json:"by"` +} + +// LayerFeature is a feature with detection information. +type LayerFeature struct { + Feature `json:"feature"` + + // By is the detector found the feature. + By Detector `json:"by"` + PotentialNamespace Namespace `json:"potentialNamespace"` +} diff --git a/database/metadata.go b/database/metadata.go new file mode 100644 index 00000000..44f588cf --- /dev/null +++ b/database/metadata.go @@ -0,0 +1,41 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "database/sql/driver" + "encoding/json" +) + +// MetadataMap is for storing the metadata returned by vulnerability database. +type MetadataMap map[string]interface{} + +func (mm *MetadataMap) Scan(value interface{}) error { + if value == nil { + return nil + } + + // github.com/lib/pq decodes TEXT/VARCHAR fields into strings. + val, ok := value.(string) + if !ok { + panic("got type other than []byte from database") + } + return json.Unmarshal([]byte(val), mm) +} + +func (mm *MetadataMap) Value() (driver.Value, error) { + json, err := json.Marshal(*mm) + return string(json), err +} diff --git a/database/models.go b/database/models.go deleted file mode 100644 index 7dd36dfe..00000000 --- a/database/models.go +++ /dev/null @@ -1,363 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package database - -import ( - "database/sql/driver" - "encoding/json" - "fmt" - "time" - - "github.com/coreos/clair/pkg/pagination" -) - -// Ancestry is a manifest that keeps all layers in an image in order. -type Ancestry struct { - // Name is a globally unique value for a set of layers. This is often the - // sha256 digest of an OCI/Docker manifest. - Name string `json:"name"` - // By contains the processors that are used when computing the - // content of this ancestry. - By []Detector `json:"by"` - // Layers should be ordered and i_th layer is the parent of i+1_th layer in - // the slice. - Layers []AncestryLayer `json:"layers"` -} - -// Valid checks if the ancestry is compliant to spec. -func (a *Ancestry) Valid() bool { - if a == nil { - return false - } - - if a.Name == "" { - return false - } - - for _, d := range a.By { - if !d.Valid() { - return false - } - } - - for _, l := range a.Layers { - if !l.Valid() { - return false - } - } - - return true -} - -// AncestryLayer is a layer with all detected namespaced features. -type AncestryLayer struct { - // Hash is the sha-256 tarsum on the layer's blob content. - Hash string `json:"hash"` - // Features are the features introduced by this layer when it was - // processed. - Features []AncestryFeature `json:"features"` -} - -// Valid checks if the Ancestry Layer is compliant to the spec. -func (l *AncestryLayer) Valid() bool { - if l == nil { - return false - } - - if l.Hash == "" { - return false - } - - return true -} - -// GetFeatures returns the Ancestry's features. -func (l *AncestryLayer) GetFeatures() []NamespacedFeature { - nsf := make([]NamespacedFeature, 0, len(l.Features)) - for _, f := range l.Features { - nsf = append(nsf, f.NamespacedFeature) - } - - return nsf -} - -// AncestryFeature is a namespaced feature with the detectors used to -// find this feature. -type AncestryFeature struct { - NamespacedFeature `json:"namespacedFeature"` - - // FeatureBy is the detector that detected the feature. - FeatureBy Detector `json:"featureBy"` - // NamespaceBy is the detector that detected the namespace. - NamespaceBy Detector `json:"namespaceBy"` -} - -// Layer is a layer with all the detected features and namespaces. -type Layer struct { - // Hash is the sha-256 tarsum on the layer's blob content. - Hash string `json:"hash"` - // By contains a list of detectors scanned this Layer. - By []Detector `json:"by"` - Namespaces []LayerNamespace `json:"namespaces"` - Features []LayerFeature `json:"features"` -} - -func (l *Layer) GetFeatures() []Feature { - features := make([]Feature, 0, len(l.Features)) - for _, f := range l.Features { - features = append(features, f.Feature) - } - - return features -} - -func (l *Layer) GetNamespaces() []Namespace { - namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features)) - for _, ns := range l.Namespaces { - namespaces = append(namespaces, ns.Namespace) - } - for _, f := range l.Features { - if f.PotentialNamespace.Valid() { - namespaces = append(namespaces, f.PotentialNamespace) - } - } - - return namespaces -} - -// LayerNamespace is a namespace with detection information. -type LayerNamespace struct { - Namespace `json:"namespace"` - - // By is the detector found the namespace. - By Detector `json:"by"` -} - -// LayerFeature is a feature with detection information. -type LayerFeature struct { - Feature `json:"feature"` - - // By is the detector found the feature. - By Detector `json:"by"` - PotentialNamespace Namespace `json:"potentialNamespace"` -} - -// Namespace is the contextual information around features. -// -// e.g. Debian:7, NodeJS. -type Namespace struct { - Name string `json:"name"` - VersionFormat string `json:"versionFormat"` -} - -func NewNamespace(name string, versionFormat string) *Namespace { - return &Namespace{name, versionFormat} -} - -func (ns *Namespace) Valid() bool { - if ns.Name == "" || ns.VersionFormat == "" { - return false - } - return true -} - -// Feature represents a package detected in a layer but the namespace is not -// determined. -// -// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary -// dpkg is the version format of the installer package manager, which in this -// case could be dpkg or apk. -type Feature struct { - Name string `json:"name"` - Version string `json:"version"` - VersionFormat string `json:"versionFormat"` - Type FeatureType `json:"type"` -} - -func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature { - return &Feature{name, version, versionFormat, featureType} -} - -func NewBinaryPackage(name string, version string, versionFormat string) *Feature { - return &Feature{name, version, versionFormat, BinaryPackage} -} - -func NewSourcePackage(name string, version string, versionFormat string) *Feature { - return &Feature{name, version, versionFormat, SourcePackage} -} - -// NamespacedFeature is a feature with determined namespace and can be affected -// by vulnerabilities. -// -// e.g. OpenSSL 1.0 dpkg Debian:7. -type NamespacedFeature struct { - Feature `json:"feature"` - - Namespace Namespace `json:"namespace"` -} - -func (nf *NamespacedFeature) Key() string { - return fmt.Sprintf("%s-%s-%s-%s-%s-%s", nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name, nf.Namespace.VersionFormat) -} - -func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature { - // TODO: namespaced feature should use pointer values - return &NamespacedFeature{*feature, *namespace} -} - -// AffectedNamespacedFeature is a namespaced feature affected by the -// vulnerabilities with fixed-in versions for this feature. -type AffectedNamespacedFeature struct { - NamespacedFeature - - AffectedBy []VulnerabilityWithFixedIn -} - -// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve -// the affecting vulnerabilities and the fixed-in versions for the feature. -type VulnerabilityWithFixedIn struct { - Vulnerability - - FixedInVersion string -} - -// AffectedFeature is used to determine whether a namespaced feature is affected -// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is -// bound to vulnerability. -type AffectedFeature struct { - // FeatureType determines which type of package it affects. - FeatureType FeatureType - Namespace Namespace - FeatureName string - // FixedInVersion is known next feature version that's not affected by the - // vulnerability. Empty FixedInVersion means the unaffected version is - // unknown. - FixedInVersion string - // AffectedVersion contains the version range to determine whether or not a - // feature is affected. - AffectedVersion string -} - -// VulnerabilityID is an identifier for every vulnerability. Every vulnerability -// has unique namespace and name. -type VulnerabilityID struct { - Name string - Namespace string -} - -// Vulnerability represents CVE or similar vulnerability reports. -type Vulnerability struct { - Name string - Namespace Namespace - - Description string - Link string - Severity Severity - - Metadata MetadataMap -} - -// VulnerabilityWithAffected is a vulnerability with all known affected -// features. -type VulnerabilityWithAffected struct { - Vulnerability - - Affected []AffectedFeature -} - -// PagedVulnerableAncestries is a vulnerability with a page of affected -// ancestries each with a special index attached for streaming purpose. The -// current page number and next page number are for navigate. -type PagedVulnerableAncestries struct { - Vulnerability - - // Affected is a map of special indexes to Ancestries, which the pair - // should be unique in a stream. Every indexes in the map should be larger - // than previous page. - Affected map[int]string - - Limit int - Current pagination.Token - Next pagination.Token - - // End signals the end of the pages. - End bool -} - -// NotificationHook is a message sent to another service to inform of a change -// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains -// the name of a notification that should be read and marked as read via the -// API. -type NotificationHook struct { - Name string - - Created time.Time - Notified time.Time - Deleted time.Time -} - -// VulnerabilityNotification is a notification for vulnerability changes. -type VulnerabilityNotification struct { - NotificationHook - - Old *Vulnerability - New *Vulnerability -} - -// VulnerabilityNotificationWithVulnerable is a notification for vulnerability -// changes with vulnerable ancestries. -type VulnerabilityNotificationWithVulnerable struct { - NotificationHook - - Old *PagedVulnerableAncestries - New *PagedVulnerableAncestries -} - -// MetadataMap is for storing the metadata returned by vulnerability database. -type MetadataMap map[string]interface{} - -// NullableAffectedNamespacedFeature is an affectednamespacedfeature with -// whether it's found in datastore. -type NullableAffectedNamespacedFeature struct { - AffectedNamespacedFeature - - Valid bool -} - -// NullableVulnerability is a vulnerability with whether the vulnerability is -// found in datastore. -type NullableVulnerability struct { - VulnerabilityWithAffected - - Valid bool -} - -func (mm *MetadataMap) Scan(value interface{}) error { - if value == nil { - return nil - } - - // github.com/lib/pq decodes TEXT/VARCHAR fields into strings. - val, ok := value.(string) - if !ok { - panic("got type other than []byte from database") - } - return json.Unmarshal([]byte(val), mm) -} - -func (mm *MetadataMap) Value() (driver.Value, error) { - json, err := json.Marshal(*mm) - return string(json), err -} diff --git a/database/namespace.go b/database/namespace.go new file mode 100644 index 00000000..33a9e2fb --- /dev/null +++ b/database/namespace.go @@ -0,0 +1,34 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +// Namespace is the contextual information around features. +// +// e.g. Debian:7, NodeJS. +type Namespace struct { + Name string `json:"name"` + VersionFormat string `json:"versionFormat"` +} + +func NewNamespace(name string, versionFormat string) *Namespace { + return &Namespace{name, versionFormat} +} + +func (ns *Namespace) Valid() bool { + if ns.Name == "" || ns.VersionFormat == "" { + return false + } + return true +} diff --git a/database/notification.go b/database/notification.go new file mode 100644 index 00000000..4b41ec25 --- /dev/null +++ b/database/notification.go @@ -0,0 +1,69 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "time" + + "github.com/coreos/clair/pkg/pagination" +) + +// NotificationHook is a message sent to another service to inform of a change +// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains +// the name of a notification that should be read and marked as read via the +// API. +type NotificationHook struct { + Name string + + Created time.Time + Notified time.Time + Deleted time.Time +} + +// VulnerabilityNotification is a notification for vulnerability changes. +type VulnerabilityNotification struct { + NotificationHook + + Old *Vulnerability + New *Vulnerability +} + +// VulnerabilityNotificationWithVulnerable is a notification for vulnerability +// changes with vulnerable ancestries. +type VulnerabilityNotificationWithVulnerable struct { + NotificationHook + + Old *PagedVulnerableAncestries + New *PagedVulnerableAncestries +} + +// PagedVulnerableAncestries is a vulnerability with a page of affected +// ancestries each with a special index attached for streaming purpose. The +// current page number and next page number are for navigate. +type PagedVulnerableAncestries struct { + Vulnerability + + // Affected is a map of special indexes to Ancestries, which the pair + // should be unique in a stream. Every indexes in the map should be larger + // than previous page. + Affected map[int]string + + Limit int + Current pagination.Token + Next pagination.Token + + // End signals the end of the pages. + End bool +} diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go deleted file mode 100644 index 40a49e86..00000000 --- a/database/pgsql/ancestry.go +++ /dev/null @@ -1,375 +0,0 @@ -// Copyright 2019 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "errors" - - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/commonerr" -) - -const ( - insertAncestry = ` - INSERT INTO ancestry (name) VALUES ($1) RETURNING id` - - findAncestryLayerHashes = ` - SELECT layer.hash, ancestry_layer.ancestry_index - FROM layer, ancestry_layer - WHERE ancestry_layer.ancestry_id = $1 - AND ancestry_layer.layer_id = layer.id - ORDER BY ancestry_layer.ancestry_index ASC` - - findAncestryFeatures = ` - SELECT namespace.name, namespace.version_format, feature.name, - feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index, - ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id - FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature - WHERE ancestry_layer.ancestry_id = $1 - AND feature_type.id = feature.type - AND ancestry_feature.ancestry_layer_id = ancestry_layer.id - AND ancestry_feature.namespaced_feature_id = namespaced_feature.id - AND namespaced_feature.feature_id = feature.id - AND namespaced_feature.namespace_id = namespace.id` - - findAncestryID = `SELECT id FROM ancestry WHERE name = $1` - removeAncestry = `DELETE FROM ancestry WHERE name = $1` - insertAncestryLayers = ` - INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3) - RETURNING id` - insertAncestryFeatures = ` - INSERT INTO ancestry_feature - (ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES - ($1, $2, $3, $4)` -) - -func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) { - var ( - ancestry = database.Ancestry{Name: name} - err error - ) - - id, ok, err := tx.findAncestryID(name) - if !ok || err != nil { - return ancestry, ok, err - } - - if ancestry.By, err = tx.findAncestryDetectors(id); err != nil { - return ancestry, false, err - } - - if ancestry.Layers, err = tx.findAncestryLayers(id); err != nil { - return ancestry, false, err - } - - return ancestry, true, nil -} - -func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error { - if !ancestry.Valid() { - return database.ErrInvalidParameters - } - - if err := tx.removeAncestry(ancestry.Name); err != nil { - return err - } - - id, err := tx.insertAncestry(ancestry.Name) - if err != nil { - return err - } - - detectorIDs, err := tx.findDetectorIDs(ancestry.By) - if err != nil { - return err - } - - // insert ancestry metadata - if err := tx.insertAncestryDetectors(id, detectorIDs); err != nil { - return err - } - - layers := make([]string, 0, len(ancestry.Layers)) - for _, layer := range ancestry.Layers { - layers = append(layers, layer.Hash) - } - - layerIDs, ok, err := tx.findLayerIDs(layers) - if err != nil { - return err - } - - if !ok { - log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.") - return database.ErrMissingEntities - } - - ancestryLayerIDs, err := tx.insertAncestryLayers(id, layerIDs) - if err != nil { - return err - } - - for i, id := range ancestryLayerIDs { - if err := tx.insertAncestryFeatures(id, ancestry.Layers[i]); err != nil { - return err - } - } - - return nil -} - -func (tx *pgSession) insertAncestry(name string) (int64, error) { - var id int64 - err := tx.QueryRow(insertAncestry, name).Scan(&id) - if err != nil { - if isErrUniqueViolation(err) { - return 0, handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)")) - } - - return 0, handleError("insertAncestry", err) - } - - return id, nil -} - -func (tx *pgSession) findAncestryID(name string) (int64, bool, error) { - var id sql.NullInt64 - if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil { - if err == sql.ErrNoRows { - return 0, false, nil - } - - return 0, false, handleError("findAncestryID", err) - } - - return id.Int64, true, nil -} - -func (tx *pgSession) removeAncestry(name string) error { - result, err := tx.Exec(removeAncestry, name) - if err != nil { - return handleError("removeAncestry", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return handleError("removeAncestry", err) - } - - if affected != 0 { - log.WithField("ancestry", name).Debug("removed ancestry") - } - - return nil -} - -func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) { - detectors, err := tx.findAllDetectors() - if err != nil { - return nil, err - } - - layerMap, err := tx.findAncestryLayerHashes(id) - if err != nil { - return nil, err - } - - featureMap, err := tx.findAncestryFeatures(id, detectors) - if err != nil { - return nil, err - } - - layers := make([]database.AncestryLayer, len(layerMap)) - for index, layer := range layerMap { - // index MUST match the ancestry layer slice index. - if layers[index].Hash == "" && len(layers[index].Features) == 0 { - layers[index] = database.AncestryLayer{ - Hash: layer, - Features: featureMap[index], - } - } else { - log.WithFields(log.Fields{ - "ancestry ID": id, - "duplicated ancestry index": index, - }).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed") - return nil, database.ErrInconsistent - } - } - - return layers, nil -} - -func (tx *pgSession) findAncestryLayerHashes(ancestryID int64) (map[int64]string, error) { - // retrieve layer indexes and hashes - rows, err := tx.Query(findAncestryLayerHashes, ancestryID) - if err != nil { - return nil, handleError("findAncestryLayerHashes", err) - } - - layerHashes := map[int64]string{} - for rows.Next() { - var ( - hash string - index int64 - ) - - if err = rows.Scan(&hash, &index); err != nil { - return nil, handleError("findAncestryLayerHashes", err) - } - - if _, ok := layerHashes[index]; ok { - // one ancestry index should correspond to only one layer - return nil, database.ErrInconsistent - } - - layerHashes[index] = hash - } - - return layerHashes, nil -} - -func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMap) (map[int64][]database.AncestryFeature, error) { - // ancestry_index -> ancestry features - featureMap := make(map[int64][]database.AncestryFeature) - // retrieve ancestry layer's namespaced features - rows, err := tx.Query(findAncestryFeatures, ancestryID) - if err != nil { - return nil, handleError("findAncestryFeatures", err) - } - - defer rows.Close() - - for rows.Next() { - var ( - featureDetectorID int64 - namespaceDetectorID int64 - feature database.NamespacedFeature - // index is used to determine which layer the feature belongs to. - index sql.NullInt64 - ) - - if err := rows.Scan( - &feature.Namespace.Name, - &feature.Namespace.VersionFormat, - &feature.Feature.Name, - &feature.Feature.Version, - &feature.Feature.VersionFormat, - &feature.Feature.Type, - &index, - &featureDetectorID, - &namespaceDetectorID, - ); err != nil { - return nil, handleError("findAncestryFeatures", err) - } - - if feature.Feature.VersionFormat != feature.Namespace.VersionFormat { - // Feature must have the same version format as the associated - // namespace version format. - return nil, database.ErrInconsistent - } - - fDetector, ok := detectors.byID[featureDetectorID] - if !ok { - return nil, database.ErrInconsistent - } - - nsDetector, ok := detectors.byID[namespaceDetectorID] - if !ok { - return nil, database.ErrInconsistent - } - - featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{ - NamespacedFeature: feature, - FeatureBy: fDetector, - NamespaceBy: nsDetector, - }) - } - - return featureMap, nil -} - -// insertAncestryLayers inserts the ancestry layers along with its content into -// the database. The layers are 0 based indexed in the original order. -func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []int64) ([]int64, error) { - stmt, err := tx.Prepare(insertAncestryLayers) - if err != nil { - return nil, handleError("insertAncestryLayers", err) - } - - ancestryLayerIDs := []int64{} - for index, layerID := range layers { - var ancestryLayerID sql.NullInt64 - if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil { - return nil, handleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close())) - } - - if !ancestryLayerID.Valid { - return nil, database.ErrInconsistent - } - - ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64) - } - - if err := stmt.Close(); err != nil { - return nil, handleError("insertAncestryLayers", err) - } - - return ancestryLayerIDs, nil -} - -func (tx *pgSession) insertAncestryFeatures(ancestryLayerID int64, layer database.AncestryLayer) error { - detectors, err := tx.findAllDetectors() - if err != nil { - return err - } - - nsFeatureIDs, err := tx.findNamespacedFeatureIDs(layer.GetFeatures()) - if err != nil { - return err - } - - // find the detectors for each feature - stmt, err := tx.Prepare(insertAncestryFeatures) - if err != nil { - return handleError("insertAncestryFeatures", err) - } - - defer stmt.Close() - - for index, id := range nsFeatureIDs { - if !id.Valid { - return database.ErrMissingEntities - } - - namespaceDetectorID, ok := detectors.byValue[layer.Features[index].NamespaceBy] - if !ok { - return database.ErrMissingEntities - } - - featureDetectorID, ok := detectors.byValue[layer.Features[index].FeatureBy] - if !ok { - return database.ErrMissingEntities - } - - if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil { - return handleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close())) - } - } - - return nil -} diff --git a/database/pgsql/ancestry/ancestry.go b/database/pgsql/ancestry/ancestry.go new file mode 100644 index 00000000..74fe0fd6 --- /dev/null +++ b/database/pgsql/ancestry/ancestry.go @@ -0,0 +1,160 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ancestry + +import ( + "database/sql" + "errors" + + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/layer" + "github.com/coreos/clair/database/pgsql/util" +) + +const ( + insertAncestry = ` + INSERT INTO ancestry (name) VALUES ($1) RETURNING id` + + findAncestryID = `SELECT id FROM ancestry WHERE name = $1` + removeAncestry = `DELETE FROM ancestry WHERE name = $1` + + insertAncestryFeatures = ` + INSERT INTO ancestry_feature + (ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES + ($1, $2, $3, $4)` +) + +func FindAncestry(tx *sql.Tx, name string) (database.Ancestry, bool, error) { + var ( + ancestry = database.Ancestry{Name: name} + err error + ) + + id, ok, err := FindAncestryID(tx, name) + if !ok || err != nil { + return ancestry, ok, err + } + + if ancestry.By, err = FindAncestryDetectors(tx, id); err != nil { + return ancestry, false, err + } + + if ancestry.Layers, err = FindAncestryLayers(tx, id); err != nil { + return ancestry, false, err + } + + return ancestry, true, nil +} + +func UpsertAncestry(tx *sql.Tx, ancestry database.Ancestry) error { + if !ancestry.Valid() { + return database.ErrInvalidParameters + } + + if err := RemoveAncestry(tx, ancestry.Name); err != nil { + return err + } + + id, err := InsertAncestry(tx, ancestry.Name) + if err != nil { + return err + } + + detectorIDs, err := detector.FindDetectorIDs(tx, ancestry.By) + if err != nil { + return err + } + + // insert ancestry metadata + if err := InsertAncestryDetectors(tx, id, detectorIDs); err != nil { + return err + } + + layers := make([]string, 0, len(ancestry.Layers)) + for _, l := range ancestry.Layers { + layers = append(layers, l.Hash) + } + + layerIDs, ok, err := layer.FindLayerIDs(tx, layers) + if err != nil { + return err + } + + if !ok { + log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.") + return database.ErrMissingEntities + } + + ancestryLayerIDs, err := InsertAncestryLayers(tx, id, layerIDs) + if err != nil { + return err + } + + for i, id := range ancestryLayerIDs { + if err := InsertAncestryFeatures(tx, id, ancestry.Layers[i]); err != nil { + return err + } + } + + return nil +} + +func InsertAncestry(tx *sql.Tx, name string) (int64, error) { + var id int64 + err := tx.QueryRow(insertAncestry, name).Scan(&id) + if err != nil { + if util.IsErrUniqueViolation(err) { + return 0, util.HandleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)")) + } + + return 0, util.HandleError("insertAncestry", err) + } + + return id, nil +} + +func FindAncestryID(tx *sql.Tx, name string) (int64, bool, error) { + var id sql.NullInt64 + if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil { + if err == sql.ErrNoRows { + return 0, false, nil + } + + return 0, false, util.HandleError("findAncestryID", err) + } + + return id.Int64, true, nil +} + +func RemoveAncestry(tx *sql.Tx, name string) error { + result, err := tx.Exec(removeAncestry, name) + if err != nil { + return util.HandleError("removeAncestry", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return util.HandleError("removeAncestry", err) + } + + if affected != 0 { + log.WithField("ancestry", name).Debug("removed ancestry") + } + + return nil +} diff --git a/database/pgsql/ancestry/ancestry_detector.go b/database/pgsql/ancestry/ancestry_detector.go new file mode 100644 index 00000000..f51adf60 --- /dev/null +++ b/database/pgsql/ancestry/ancestry_detector.go @@ -0,0 +1,48 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ancestry + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/util" +) + +var selectAncestryDetectors = ` +SELECT d.name, d.version, d.dtype + FROM ancestry_detector, detector AS d + WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;` + +var insertAncestryDetectors = ` + INSERT INTO ancestry_detector (ancestry_id, detector_id) + SELECT $1, $2 + WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)` + +func FindAncestryDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) { + detectors, err := detector.GetDetectors(tx, selectAncestryDetectors, id) + return detectors, err +} + +func InsertAncestryDetectors(tx *sql.Tx, ancestryID int64, detectorIDs []int64) error { + for _, detectorID := range detectorIDs { + if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil { + return util.HandleError("insertAncestryDetectors", err) + } + } + + return nil +} diff --git a/database/pgsql/ancestry/ancestry_feature.go b/database/pgsql/ancestry/ancestry_feature.go new file mode 100644 index 00000000..33096d20 --- /dev/null +++ b/database/pgsql/ancestry/ancestry_feature.go @@ -0,0 +1,139 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ancestry + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/pkg/commonerr" +) + +const findAncestryFeatures = ` + SELECT namespace.name, namespace.version_format, feature.name, + feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index, + ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id + FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature + WHERE ancestry_layer.ancestry_id = $1 + AND feature_type.id = feature.type + AND ancestry_feature.ancestry_layer_id = ancestry_layer.id + AND ancestry_feature.namespaced_feature_id = namespaced_feature.id + AND namespaced_feature.feature_id = feature.id + AND namespaced_feature.namespace_id = namespace.id` + +func FindAncestryFeatures(tx *sql.Tx, ancestryID int64, detectors detector.DetectorMap) (map[int64][]database.AncestryFeature, error) { + // ancestry_index -> ancestry features + featureMap := make(map[int64][]database.AncestryFeature) + // retrieve ancestry layer's namespaced features + rows, err := tx.Query(findAncestryFeatures, ancestryID) + if err != nil { + return nil, util.HandleError("findAncestryFeatures", err) + } + + defer rows.Close() + + for rows.Next() { + var ( + featureDetectorID int64 + namespaceDetectorID int64 + feature database.NamespacedFeature + // index is used to determine which layer the feature belongs to. + index sql.NullInt64 + ) + + if err := rows.Scan( + &feature.Namespace.Name, + &feature.Namespace.VersionFormat, + &feature.Feature.Name, + &feature.Feature.Version, + &feature.Feature.VersionFormat, + &feature.Feature.Type, + &index, + &featureDetectorID, + &namespaceDetectorID, + ); err != nil { + return nil, util.HandleError("findAncestryFeatures", err) + } + + if feature.Feature.VersionFormat != feature.Namespace.VersionFormat { + // Feature must have the same version format as the associated + // namespace version format. + return nil, database.ErrInconsistent + } + + fDetector, ok := detectors.ByID[featureDetectorID] + if !ok { + return nil, database.ErrInconsistent + } + + nsDetector, ok := detectors.ByID[namespaceDetectorID] + if !ok { + return nil, database.ErrInconsistent + } + + featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{ + NamespacedFeature: feature, + FeatureBy: fDetector, + NamespaceBy: nsDetector, + }) + } + + return featureMap, nil +} + +func InsertAncestryFeatures(tx *sql.Tx, ancestryLayerID int64, layer database.AncestryLayer) error { + detectors, err := detector.FindAllDetectors(tx) + if err != nil { + return err + } + + nsFeatureIDs, err := feature.FindNamespacedFeatureIDs(tx, layer.GetFeatures()) + if err != nil { + return err + } + + // find the detectors for each feature + stmt, err := tx.Prepare(insertAncestryFeatures) + if err != nil { + return util.HandleError("insertAncestryFeatures", err) + } + + defer stmt.Close() + + for index, id := range nsFeatureIDs { + if !id.Valid { + return database.ErrMissingEntities + } + + namespaceDetectorID, ok := detectors.ByValue[layer.Features[index].NamespaceBy] + if !ok { + return database.ErrMissingEntities + } + + featureDetectorID, ok := detectors.ByValue[layer.Features[index].FeatureBy] + if !ok { + return database.ErrMissingEntities + } + + if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil { + return util.HandleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close())) + } + } + + return nil +} diff --git a/database/pgsql/ancestry/ancestry_layer.go b/database/pgsql/ancestry/ancestry_layer.go new file mode 100644 index 00000000..be053037 --- /dev/null +++ b/database/pgsql/ancestry/ancestry_layer.go @@ -0,0 +1,131 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ancestry + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/pkg/commonerr" + log "github.com/sirupsen/logrus" +) + +const ( + findAncestryLayerHashes = ` + SELECT layer.hash, ancestry_layer.ancestry_index + FROM layer, ancestry_layer + WHERE ancestry_layer.ancestry_id = $1 + AND ancestry_layer.layer_id = layer.id + ORDER BY ancestry_layer.ancestry_index ASC` + insertAncestryLayers = ` + INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3) + RETURNING id` +) + +func FindAncestryLayerHashes(tx *sql.Tx, ancestryID int64) (map[int64]string, error) { + // retrieve layer indexes and hashes + rows, err := tx.Query(findAncestryLayerHashes, ancestryID) + if err != nil { + return nil, util.HandleError("findAncestryLayerHashes", err) + } + + layerHashes := map[int64]string{} + for rows.Next() { + var ( + hash string + index int64 + ) + + if err = rows.Scan(&hash, &index); err != nil { + return nil, util.HandleError("findAncestryLayerHashes", err) + } + + if _, ok := layerHashes[index]; ok { + // one ancestry index should correspond to only one layer + return nil, database.ErrInconsistent + } + + layerHashes[index] = hash + } + + return layerHashes, nil +} + +// insertAncestryLayers inserts the ancestry layers along with its content into +// the database. The layers are 0 based indexed in the original order. +func InsertAncestryLayers(tx *sql.Tx, ancestryID int64, layers []int64) ([]int64, error) { + stmt, err := tx.Prepare(insertAncestryLayers) + if err != nil { + return nil, util.HandleError("insertAncestryLayers", err) + } + + ancestryLayerIDs := []int64{} + for index, layerID := range layers { + var ancestryLayerID sql.NullInt64 + if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil { + return nil, util.HandleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close())) + } + + if !ancestryLayerID.Valid { + return nil, database.ErrInconsistent + } + + ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64) + } + + if err := stmt.Close(); err != nil { + return nil, util.HandleError("insertAncestryLayers", err) + } + + return ancestryLayerIDs, nil +} + +func FindAncestryLayers(tx *sql.Tx, id int64) ([]database.AncestryLayer, error) { + detectors, err := detector.FindAllDetectors(tx) + if err != nil { + return nil, err + } + + layerMap, err := FindAncestryLayerHashes(tx, id) + if err != nil { + return nil, err + } + + featureMap, err := FindAncestryFeatures(tx, id, detectors) + if err != nil { + return nil, err + } + + layers := make([]database.AncestryLayer, len(layerMap)) + for index, layer := range layerMap { + // index MUST match the ancestry layer slice index. + if layers[index].Hash == "" && len(layers[index].Features) == 0 { + layers[index] = database.AncestryLayer{ + Hash: layer, + Features: featureMap[index], + } + } else { + log.WithFields(log.Fields{ + "ancestry ID": id, + "duplicated ancestry index": index, + }).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed") + return nil, database.ErrInconsistent + } + } + + return layers, nil +} diff --git a/database/pgsql/ancestry_test.go b/database/pgsql/ancestry/ancestry_test.go similarity index 74% rename from database/pgsql/ancestry_test.go rename to database/pgsql/ancestry/ancestry_test.go index 7f0a37f8..992c80bf 100644 --- a/database/pgsql/ancestry_test.go +++ b/database/pgsql/ancestry/ancestry_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2019 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package ancestry import ( "testing" @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) var upsertAncestryTests = []struct { @@ -55,9 +56,9 @@ var upsertAncestryTests = []struct { title: "ancestry with invalid feature", in: &database.Ancestry{ Name: "a", - By: []database.Detector{realDetectors[1], realDetectors[2]}, + By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]}, Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{ - {fakeNamespacedFeatures[1], fakeDetector[1], fakeDetector[2]}, + {testutil.FakeNamespacedFeatures[1], testutil.FakeDetector[1], testutil.FakeDetector[2]}, }}}, }, err: database.ErrMissingEntities.Error(), @@ -66,26 +67,27 @@ var upsertAncestryTests = []struct { title: "replace old ancestry", in: &database.Ancestry{ Name: "a", - By: []database.Detector{realDetectors[1], realDetectors[2]}, + By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]}, Layers: []database.AncestryLayer{ - {"layer-1", []database.AncestryFeature{{realNamespacedFeatures[1], realDetectors[2], realDetectors[1]}}}, + {"layer-1", []database.AncestryFeature{{testutil.RealNamespacedFeatures[1], testutil.RealDetectors[2], testutil.RealDetectors[1]}}}, }, }, }, } func TestUpsertAncestry(t *testing.T) { - store, tx := openSessionForTest(t, "UpsertAncestry", true) - defer closeTest(t, store, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestUpsertAncestry") + defer cleanup() + for _, test := range upsertAncestryTests { t.Run(test.title, func(t *testing.T) { - err := tx.UpsertAncestry(*test.in) + err := UpsertAncestry(tx, *test.in) if test.err != "" { assert.EqualError(t, err, test.err, "unexpected error") return } assert.Nil(t, err) - actual, ok, err := tx.FindAncestry(test.in.Name) + actual, ok, err := FindAncestry(tx, test.in.Name) assert.Nil(t, err) assert.True(t, ok) database.AssertAncestryEqual(t, test.in, &actual) @@ -113,16 +115,17 @@ var findAncestryTests = []struct { in: "ancestry-2", err: "", ok: true, - ancestry: takeAncestryPointerFromMap(realAncestries, 2), + ancestry: testutil.TakeAncestryPointerFromMap(testutil.RealAncestries, 2), }, } func TestFindAncestry(t *testing.T) { - store, tx := openSessionForTest(t, "FindAncestry", true) - defer closeTest(t, store, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindAncestry") + defer cleanup() + for _, test := range findAncestryTests { t.Run(test.title, func(t *testing.T) { - ancestry, ok, err := tx.FindAncestry(test.in) + ancestry, ok, err := FindAncestry(tx, test.in) if test.err != "" { assert.EqualError(t, err, test.err, "unexpected error") return diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go deleted file mode 100644 index 5b42ccfa..00000000 --- a/database/pgsql/complex_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "math/rand" - "runtime" - "strconv" - "sync" - "testing" - "time" - - "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/strutil" -) - -const ( - numVulnerabilities = 100 - numFeatures = 100 -) - -func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { - tx, err := store.Begin() - if !assert.Nil(t, err) { - t.FailNow() - } - - featureName := "TestFeature" - featureVersionFormat := dpkg.ParserName - // Insert the namespace on which we'll work. - namespace := database.Namespace{ - Name: "TestRaceAffectsFeatureNamespace1", - VersionFormat: dpkg.ParserName, - } - - if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) { - t.FailNow() - } - - // Initialize random generator and enforce max procs. - rand.Seed(time.Now().UnixNano()) - runtime.GOMAXPROCS(runtime.NumCPU()) - - // Generate Distinct random features - features := make([]database.Feature, numFeatures) - nsFeatures := make([]database.NamespacedFeature, numFeatures) - for i := 0; i < numFeatures; i++ { - version := rand.Intn(numFeatures) - - features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat) - nsFeatures[i] = database.NamespacedFeature{ - Namespace: namespace, - Feature: features[i], - } - } - - if !assert.Nil(t, tx.PersistFeatures(features)) { - t.FailNow() - } - - // Generate vulnerabilities. - vulnerabilities := []database.VulnerabilityWithAffected{} - for i := 0; i < numVulnerabilities; i++ { - // any version less than this is vulnerable - version := rand.Intn(numFeatures) + 1 - - vulnerability := database.VulnerabilityWithAffected{ - Vulnerability: database.Vulnerability{ - Name: uuid.New(), - Namespace: namespace, - Severity: database.UnknownSeverity, - }, - Affected: []database.AffectedFeature{ - { - Namespace: namespace, - FeatureName: featureName, - FeatureType: database.SourcePackage, - AffectedVersion: strconv.Itoa(version), - FixedInVersion: strconv.Itoa(version), - }, - }, - } - - vulnerabilities = append(vulnerabilities, vulnerability) - } - tx.Commit() - - return nsFeatures, vulnerabilities -} - -func TestConcurrency(t *testing.T) { - store, cleanup := createTestPgSQL(t, "concurrency") - defer cleanup() - - var wg sync.WaitGroup - // there's a limit on the number of concurrent connections in the pool - wg.Add(30) - for i := 0; i < 30; i++ { - go func() { - defer wg.Done() - nsNamespaces := genRandomNamespaces(t, 100) - tx, err := store.Begin() - require.Nil(t, err) - require.Nil(t, tx.PersistNamespaces(nsNamespaces)) - require.Nil(t, tx.Commit()) - }() - } - - wg.Wait() -} - -func TestCaching(t *testing.T) { - store, cleanup := createTestPgSQL(t, "caching") - defer cleanup() - - nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) - tx, err := store.Begin() - require.Nil(t, err) - - require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) - require.Nil(t, tx.Commit()) - - tx, err = store.Begin() - require.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) - require.Nil(t, tx.Commit()) - - tx, err = store.Begin() - require.Nil(t, err) - defer tx.Rollback() - - affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) - require.Nil(t, err) - - for _, ansf := range affected { - require.True(t, ansf.Valid) - - expectedAffectedNames := []string{} - for _, vuln := range vulnerabilities { - if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { - if ok { - expectedAffectedNames = append(expectedAffectedNames, vuln.Name) - } - } - } - - actualAffectedNames := []string{} - for _, s := range ansf.AffectedBy { - actualAffectedNames = append(actualAffectedNames, s.Name) - } - - require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames) - require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) - } -} diff --git a/database/pgsql/detector.go b/database/pgsql/detector.go deleted file mode 100644 index 3209d632..00000000 --- a/database/pgsql/detector.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - - "github.com/deckarep/golang-set" - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair/database" -) - -const ( - soiDetector = ` - INSERT INTO detector (name, version, dtype) - SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type ) - WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);` - - selectAncestryDetectors = ` - SELECT d.name, d.version, d.dtype - FROM ancestry_detector, detector AS d - WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;` - - selectLayerDetectors = ` - SELECT d.name, d.version, d.dtype - FROM layer_detector, detector AS d - WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;` - - insertAncestryDetectors = ` - INSERT INTO ancestry_detector (ancestry_id, detector_id) - SELECT $1, $2 - WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)` - - persistLayerDetector = ` - INSERT INTO layer_detector (layer_id, detector_id) - SELECT $1, $2 - WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)` - - findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3` - findAllDetectors = `SELECT id, name, version, dtype FROM detector` -) - -type detectorMap struct { - byID map[int64]database.Detector - byValue map[database.Detector]int64 -} - -func (tx *pgSession) PersistDetectors(detectors []database.Detector) error { - for _, d := range detectors { - if !d.Valid() { - log.WithField("detector", d).Debug("Invalid Detector") - return database.ErrInvalidParameters - } - - r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType) - if err != nil { - return handleError("soiDetector", err) - } - - count, err := r.RowsAffected() - if err != nil { - return handleError("soiDetector", err) - } - - if count == 0 { - log.Debug("detector already exists: ", d) - } - } - - return nil -} - -func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error { - if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil { - return handleError("persistLayerDetector", err) - } - - return nil -} - -func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error { - alreadySaved := mapset.NewSet() - for _, id := range detectorIDs { - if alreadySaved.Contains(id) { - continue - } - - alreadySaved.Add(id) - if err := tx.persistLayerDetector(layerID, id); err != nil { - return err - } - } - - return nil -} - -func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error { - for _, detectorID := range detectorIDs { - if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil { - return handleError("insertAncestryDetectors", err) - } - } - - return nil -} - -func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) { - detectors, err := tx.getDetectors(selectAncestryDetectors, id) - return detectors, err -} - -func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) { - detectors, err := tx.getDetectors(selectLayerDetectors, id) - return detectors, err -} - -// findDetectorIDs retrieve ids of the detectors from the database, if any is not -// found, return the error. -func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) { - ids := []int64{} - for _, d := range detectors { - id := sql.NullInt64{} - err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id) - if err != nil { - return nil, handleError("findDetectorID", err) - } - - if !id.Valid { - return nil, database.ErrInconsistent - } - - ids = append(ids, id.Int64) - } - - return ids, nil -} - -func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) { - rows, err := tx.Query(query, id) - if err != nil { - return nil, handleError("getDetectors", err) - } - - detectors := []database.Detector{} - for rows.Next() { - d := database.Detector{} - err := rows.Scan(&d.Name, &d.Version, &d.DType) - if err != nil { - return nil, handleError("getDetectors", err) - } - - if !d.Valid() { - return nil, database.ErrInvalidDetector - } - - detectors = append(detectors, d) - } - - return detectors, nil -} - -func (tx *pgSession) findAllDetectors() (detectorMap, error) { - rows, err := tx.Query(findAllDetectors) - if err != nil { - return detectorMap{}, handleError("searchAllDetectors", err) - } - - detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)} - for rows.Next() { - var ( - id int64 - d database.Detector - ) - if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil { - return detectorMap{}, handleError("searchAllDetectors", err) - } - - detectors.byID[id] = d - detectors.byValue[d] = id - } - - return detectors, nil -} diff --git a/database/pgsql/detector/detector.go b/database/pgsql/detector/detector.go new file mode 100644 index 00000000..08643c90 --- /dev/null +++ b/database/pgsql/detector/detector.go @@ -0,0 +1,132 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package detector + +import ( + "database/sql" + + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" +) + +const ( + soiDetector = ` + INSERT INTO detector (name, version, dtype) + SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type ) + WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);` + + findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3` + findAllDetectors = `SELECT id, name, version, dtype FROM detector` +) + +type DetectorMap struct { + ByID map[int64]database.Detector + ByValue map[database.Detector]int64 +} + +func PersistDetectors(tx *sql.Tx, detectors []database.Detector) error { + for _, d := range detectors { + if !d.Valid() { + log.WithField("detector", d).Debug("Invalid Detector") + return database.ErrInvalidParameters + } + + r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType) + if err != nil { + return util.HandleError("soiDetector", err) + } + + count, err := r.RowsAffected() + if err != nil { + return util.HandleError("soiDetector", err) + } + + if count == 0 { + log.Debug("detector already exists: ", d) + } + } + + return nil +} + +// findDetectorIDs retrieve ids of the detectors from the database, if any is not +// found, return the error. +func FindDetectorIDs(tx *sql.Tx, detectors []database.Detector) ([]int64, error) { + ids := []int64{} + for _, d := range detectors { + id := sql.NullInt64{} + err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id) + if err != nil { + return nil, util.HandleError("findDetectorID", err) + } + + if !id.Valid { + return nil, database.ErrInconsistent + } + + ids = append(ids, id.Int64) + } + + return ids, nil +} + +func GetDetectors(tx *sql.Tx, query string, id int64) ([]database.Detector, error) { + rows, err := tx.Query(query, id) + if err != nil { + return nil, util.HandleError("getDetectors", err) + } + + detectors := []database.Detector{} + for rows.Next() { + d := database.Detector{} + err := rows.Scan(&d.Name, &d.Version, &d.DType) + if err != nil { + return nil, util.HandleError("getDetectors", err) + } + + if !d.Valid() { + return nil, database.ErrInvalidDetector + } + + detectors = append(detectors, d) + } + + return detectors, nil +} + +func FindAllDetectors(tx *sql.Tx) (DetectorMap, error) { + rows, err := tx.Query(findAllDetectors) + if err != nil { + return DetectorMap{}, util.HandleError("searchAllDetectors", err) + } + + detectors := DetectorMap{ByID: make(map[int64]database.Detector), ByValue: make(map[database.Detector]int64)} + for rows.Next() { + var ( + id int64 + d database.Detector + ) + if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil { + return DetectorMap{}, util.HandleError("searchAllDetectors", err) + } + + detectors.ByID[id] = d + detectors.ByValue[d] = id + } + + return detectors, nil +} diff --git a/database/pgsql/detector_test.go b/database/pgsql/detector/detector_test.go similarity index 91% rename from database/pgsql/detector_test.go rename to database/pgsql/detector/detector_test.go index 582da60b..27e3fad5 100644 --- a/database/pgsql/detector_test.go +++ b/database/pgsql/detector/detector_test.go @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package detector import ( + "database/sql" "testing" "github.com/deckarep/golang-set" "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) -func testGetAllDetectors(tx *pgSession) []database.Detector { +func testGetAllDetectors(tx *sql.Tx) []database.Detector { query := `SELECT name, version, dtype FROM detector` rows, err := tx.Query(query) if err != nil { @@ -90,12 +92,12 @@ var persistDetectorTests = []struct { } func TestPersistDetector(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistDetector", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistDetector") + defer cleanup() for _, test := range persistDetectorTests { t.Run(test.title, func(t *testing.T) { - err := tx.PersistDetectors(test.in) + err := PersistDetectors(tx, test.in) if test.err != "" { require.EqualError(t, err, test.err) return diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go deleted file mode 100644 index 66b47c50..00000000 --- a/database/pgsql/feature.go +++ /dev/null @@ -1,398 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "sort" - - "github.com/lib/pq" - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/pkg/commonerr" -) - -const ( - soiNamespacedFeature = ` - WITH new_feature_ns AS ( - INSERT INTO namespaced_feature(feature_id, namespace_id) - SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER) - WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2) - RETURNING id - ) - SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2 - UNION - SELECT id FROM new_feature_ns` - - searchPotentialAffectingVulneraibilities = ` - SELECT nf.id, v.id, vaf.affected_version, vaf.id - FROM vulnerability_affected_feature AS vaf, vulnerability AS v, - namespaced_feature AS nf, feature AS f - WHERE nf.id = ANY($1) - AND nf.feature_id = f.id - AND nf.namespace_id = v.namespace_id - AND vaf.feature_name = f.name - AND vaf.feature_type = f.type - AND vaf.vulnerability_id = v.id - AND v.deleted_at IS NULL` - - searchNamespacedFeaturesVulnerabilities = ` - SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, - v.severity, v.metadata, vaf.fixedin, n.name, n.version_format - FROM vulnerability_affected_namespaced_feature AS vanf, - Vulnerability AS v, - vulnerability_affected_feature AS vaf, - namespace AS n - WHERE vanf.namespaced_feature_id = ANY($1) - AND vaf.id = vanf.added_by - AND v.id = vanf.vulnerability_id - AND n.id = v.namespace_id - AND v.deleted_at IS NULL` -) - -func (tx *pgSession) PersistFeatures(features []database.Feature) error { - if len(features) == 0 { - return nil - } - - types, err := tx.getFeatureTypeMap() - if err != nil { - return err - } - - // Sorting is needed before inserting into database to prevent deadlock. - sort.Slice(features, func(i, j int) bool { - return features[i].Name < features[j].Name || - features[i].Version < features[j].Version || - features[i].VersionFormat < features[j].VersionFormat - }) - - // TODO(Sida): A better interface for bulk insertion is needed. - keys := make([]interface{}, 0, len(features)*3) - for _, f := range features { - keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type]) - if f.Name == "" || f.Version == "" || f.VersionFormat == "" { - return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") - } - } - - _, err = tx.Exec(queryPersistFeature(len(features)), keys...) - return handleError("queryPersistFeature", err) -} - -type namespacedFeatureWithID struct { - database.NamespacedFeature - - ID int64 -} - -type vulnerabilityCache struct { - nsFeatureID int64 - vulnID int64 - vulnAffectingID int64 -} - -func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) { - if len(features) == 0 { - return nil, nil - } - - ids, err := tx.findNamespacedFeatureIDs(features) - if err != nil { - return nil, err - } - - fMap := map[int64]database.NamespacedFeature{} - for i, f := range features { - if !ids[i].Valid { - return nil, database.ErrMissingEntities - } - fMap[ids[i].Int64] = f - } - - cacheTable := []vulnerabilityCache{} - rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids)) - if err != nil { - return nil, handleError("searchPotentialAffectingVulneraibilities", err) - } - - defer rows.Close() - for rows.Next() { - var ( - cache vulnerabilityCache - affected string - ) - - err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID) - if err != nil { - return nil, err - } - - if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil { - return nil, err - } else if ok { - cacheTable = append(cacheTable, cache) - } - } - - return cacheTable, nil -} - -func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error { - if len(features) == 0 { - return nil - } - - _, err := tx.Exec(lockVulnerabilityAffects) - if err != nil { - return handleError("lockVulnerabilityAffects", err) - } - - cache, err := tx.searchAffectingVulnerabilities(features) - if err != nil { - return err - } - - keys := make([]interface{}, 0, len(cache)*3) - for _, c := range cache { - keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID) - } - - if len(cache) == 0 { - return nil - } - - affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...) - if err != nil { - return handleError("persistVulnerabilityAffectedNamespacedFeature", err) - } - if count, err := affected.RowsAffected(); err != nil { - log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count) - } - return nil -} - -func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error { - if len(features) == 0 { - return nil - } - - nsIDs := map[database.Namespace]sql.NullInt64{} - fIDs := map[database.Feature]sql.NullInt64{} - for _, f := range features { - nsIDs[f.Namespace] = sql.NullInt64{} - fIDs[f.Feature] = sql.NullInt64{} - } - - fToFind := []database.Feature{} - for f := range fIDs { - fToFind = append(fToFind, f) - } - - sort.Slice(fToFind, func(i, j int) bool { - return fToFind[i].Name < fToFind[j].Name || - fToFind[i].Version < fToFind[j].Version || - fToFind[i].VersionFormat < fToFind[j].VersionFormat - }) - - if ids, err := tx.findFeatureIDs(fToFind); err == nil { - for i, id := range ids { - if !id.Valid { - return database.ErrMissingEntities - } - fIDs[fToFind[i]] = id - } - } else { - return err - } - - nsToFind := []database.Namespace{} - for ns := range nsIDs { - nsToFind = append(nsToFind, ns) - } - - if ids, err := tx.findNamespaceIDs(nsToFind); err == nil { - for i, id := range ids { - if !id.Valid { - return database.ErrMissingEntities - } - nsIDs[nsToFind[i]] = id - } - } else { - return err - } - - keys := make([]interface{}, 0, len(features)*2) - for _, f := range features { - keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace]) - } - - _, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...) - if err != nil { - return err - } - - return nil -} - -// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the -// feature. -func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { - if len(features) == 0 { - return nil, nil - } - - vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) - featureIDs, err := tx.findNamespacedFeatureIDs(features) - if err != nil { - return nil, err - } - - for i, id := range featureIDs { - if id.Valid { - vulnerableFeatures[i].Valid = true - vulnerableFeatures[i].NamespacedFeature = features[i] - } - } - - rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs)) - if err != nil { - return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) - } - defer rows.Close() - - for rows.Next() { - var ( - featureID int64 - vuln database.VulnerabilityWithFixedIn - ) - - err := rows.Scan(&featureID, - &vuln.Name, - &vuln.Description, - &vuln.Link, - &vuln.Severity, - &vuln.Metadata, - &vuln.FixedInVersion, - &vuln.Namespace.Name, - &vuln.Namespace.VersionFormat, - ) - - if err != nil { - return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) - } - - for i, id := range featureIDs { - if id.Valid && id.Int64 == featureID { - vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln) - } - } - } - - return vulnerableFeatures, nil -} - -func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { - if len(nfs) == 0 { - return nil, nil - } - - nfsMap := map[database.NamespacedFeature]int64{} - keys := make([]interface{}, 0, len(nfs)*5) - for _, nf := range nfs { - keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name) - } - - rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) - if err != nil { - return nil, handleError("searchNamespacedFeature", err) - } - - defer rows.Close() - var ( - id int64 - nf database.NamespacedFeature - ) - - for rows.Next() { - err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name) - nf.Namespace.VersionFormat = nf.VersionFormat - if err != nil { - return nil, handleError("searchNamespacedFeature", err) - } - nfsMap[nf] = id - } - - ids := make([]sql.NullInt64, len(nfs)) - for i, nf := range nfs { - if id, ok := nfsMap[nf]; ok { - ids[i] = sql.NullInt64{id, true} - } else { - ids[i] = sql.NullInt64{} - } - } - - return ids, nil -} - -func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) { - if len(fs) == 0 { - return nil, nil - } - - types, err := tx.getFeatureTypeMap() - if err != nil { - return nil, err - } - - fMap := map[database.Feature]sql.NullInt64{} - - keys := make([]interface{}, 0, len(fs)*4) - for _, f := range fs { - typeID := types.byName[f.Type] - keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID) - fMap[f] = sql.NullInt64{} - } - - rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...) - if err != nil { - return nil, handleError("querySearchFeatureID", err) - } - defer rows.Close() - - var ( - id sql.NullInt64 - f database.Feature - ) - for rows.Next() { - var typeID int - err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID) - if err != nil { - return nil, handleError("querySearchFeatureID", err) - } - - f.Type = types.byID[typeID] - fMap[f] = id - } - - ids := make([]sql.NullInt64, len(fs)) - for i, f := range fs { - ids[i] = fMap[f] - } - - return ids, nil -} diff --git a/database/pgsql/feature/feature.go b/database/pgsql/feature/feature.go new file mode 100644 index 00000000..b5fc989f --- /dev/null +++ b/database/pgsql/feature/feature.go @@ -0,0 +1,121 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package feature + +import ( + "database/sql" + "fmt" + "sort" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/pkg/commonerr" +) + +func queryPersistFeature(count int) string { + return util.QueryPersist(count, + "feature", + "feature_name_version_version_format_type_key", + "name", + "version", + "version_format", + "type") +} + +func querySearchFeatureID(featureCount int) string { + return fmt.Sprintf(` + SELECT id, name, version, version_format, type + FROM Feature WHERE (name, version, version_format, type) IN (%s)`, + util.QueryString(4, featureCount), + ) +} + +func PersistFeatures(tx *sql.Tx, features []database.Feature) error { + if len(features) == 0 { + return nil + } + + types, err := GetFeatureTypeMap(tx) + if err != nil { + return err + } + + // Sorting is needed before inserting into database to prevent deadlock. + sort.Slice(features, func(i, j int) bool { + return features[i].Name < features[j].Name || + features[i].Version < features[j].Version || + features[i].VersionFormat < features[j].VersionFormat + }) + + // TODO(Sida): A better interface for bulk insertion is needed. + keys := make([]interface{}, 0, len(features)*3) + for _, f := range features { + keys = append(keys, f.Name, f.Version, f.VersionFormat, types.ByName[f.Type]) + if f.Name == "" || f.Version == "" || f.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") + } + } + + _, err = tx.Exec(queryPersistFeature(len(features)), keys...) + return util.HandleError("queryPersistFeature", err) +} + +func FindFeatureIDs(tx *sql.Tx, fs []database.Feature) ([]sql.NullInt64, error) { + if len(fs) == 0 { + return nil, nil + } + + types, err := GetFeatureTypeMap(tx) + if err != nil { + return nil, err + } + + fMap := map[database.Feature]sql.NullInt64{} + + keys := make([]interface{}, 0, len(fs)*4) + for _, f := range fs { + typeID := types.ByName[f.Type] + keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID) + fMap[f] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...) + if err != nil { + return nil, util.HandleError("querySearchFeatureID", err) + } + defer rows.Close() + + var ( + id sql.NullInt64 + f database.Feature + ) + for rows.Next() { + var typeID int + err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID) + if err != nil { + return nil, util.HandleError("querySearchFeatureID", err) + } + + f.Type = types.ByID[typeID] + fMap[f] = id + } + + ids := make([]sql.NullInt64, len(fs)) + for i, f := range fs { + ids[i] = fMap[f] + } + + return ids, nil +} diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature/feature_test.go similarity index 58% rename from database/pgsql/feature_test.go rename to database/pgsql/feature/feature_test.go index 574bfeab..fadf80ca 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature/feature_test.go @@ -12,36 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package feature import ( + "database/sql" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) func TestPersistFeatures(t *testing.T) { - tx, cleanup := createTestPgSession(t, "TestPersistFeatures") + tx, cleanup := testutil.CreateTestTx(t, "TestPersistFeatures") defer cleanup() invalid := database.Feature{} valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg") // invalid - require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid})) + require.NotNil(t, PersistFeatures(tx, []database.Feature{invalid})) // existing - require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) - require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) + require.Nil(t, PersistFeatures(tx, []database.Feature{valid})) + require.Nil(t, PersistFeatures(tx, []database.Feature{valid})) features := selectAllFeatures(t, tx) assert.Equal(t, []database.Feature{valid}, features) } func TestPersistNamespacedFeatures(t *testing.T) { - tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures") + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestPersistNamespacedFeatures") defer cleanup() // existing features @@ -58,42 +60,17 @@ func TestPersistNamespacedFeatures(t *testing.T) { nf2 := database.NewNamespacedFeature(n2, f2) // namespaced features with namespaces or features not in the database will // generate error. - assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) - assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2})) + assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{})) + assert.NotNil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1, *nf2})) // valid case: insert nf3 - assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1})) + assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1})) all := listNamespacedFeatures(t, tx) assert.Contains(t, all, *nf1) } -func TestFindAffectedNamespacedFeatures(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true) - defer closeTest(t, datastore, tx) - ns := database.NamespacedFeature{ - Feature: database.Feature{ - Name: "openssl", - Version: "1.0", - VersionFormat: "dpkg", - Type: database.SourcePackage, - }, - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - }, - } - - ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns}) - if assert.Nil(t, err) && - assert.Len(t, ans, 1) && - assert.True(t, ans[0].Valid) && - assert.Len(t, ans[0].AffectedBy, 1) { - assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) - } -} - -func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { - types, err := tx.getFeatureTypeMap() +func listNamespacedFeatures(t *testing.T, tx *sql.Tx) []database.NamespacedFeature { + types, err := GetFeatureTypeMap(tx) if err != nil { panic(err) } @@ -114,15 +91,15 @@ func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFe panic(err) } - f.Type = types.byID[typeID] + f.Type = types.ByID[typeID] nf = append(nf, f) } return nf } -func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature { - types, err := tx.getFeatureTypeMap() +func selectAllFeatures(t *testing.T, tx *sql.Tx) []database.Feature { + types, err := GetFeatureTypeMap(tx) if err != nil { panic(err) } @@ -137,7 +114,7 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature { f := database.Feature{} var typeID int err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID) - f.Type = types.byID[typeID] + f.Type = types.ByID[typeID] if err != nil { t.FailNow() } @@ -146,45 +123,24 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature { return fs } -func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.NamespacedFeature]bool{} - for _, nf := range expected { - has[nf] = false - } - - for _, nf := range actual { - has[nf] = true - } - - for nf, visited := range has { - if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { - return false - } - } - return true - } - return false -} - func TestFindNamespacedFeatureIDs(t *testing.T) { - tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs") + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNamespacedFeatureIDs") defer cleanup() features := []database.NamespacedFeature{} expectedIDs := []int{} - for id, feature := range realNamespacedFeatures { + for id, feature := range testutil.RealNamespacedFeatures { features = append(features, feature) expectedIDs = append(expectedIDs, id) } - features = append(features, realNamespacedFeatures[1]) // test duplicated + features = append(features, testutil.RealNamespacedFeatures[1]) // test duplicated expectedIDs = append(expectedIDs, 1) - namespace := realNamespaces[1] + namespace := testutil.RealNamespaces[1] features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature - ids, err := tx.findNamespacedFeatureIDs(features) + ids, err := FindNamespacedFeatureIDs(tx, features) require.Nil(t, err) require.Len(t, ids, len(expectedIDs)+1) for i, id := range ids { diff --git a/database/pgsql/feature_type.go b/database/pgsql/feature/feature_type.go similarity index 71% rename from database/pgsql/feature_type.go rename to database/pgsql/feature/feature_type.go index bccf0cd8..7a290556 100644 --- a/database/pgsql/feature_type.go +++ b/database/pgsql/feature/feature_type.go @@ -12,24 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package feature -import "github.com/coreos/clair/database" +import ( + "database/sql" + + "github.com/coreos/clair/database" +) const ( selectAllFeatureTypes = `SELECT id, name FROM feature_type` ) -type featureTypes struct { - byID map[int]database.FeatureType - byName map[database.FeatureType]int +type FeatureTypes struct { + ByID map[int]database.FeatureType + ByName map[database.FeatureType]int } -func newFeatureTypes() *featureTypes { - return &featureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)} +func newFeatureTypes() *FeatureTypes { + return &FeatureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)} } -func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) { +func GetFeatureTypeMap(tx *sql.Tx) (*FeatureTypes, error) { rows, err := tx.Query(selectAllFeatureTypes) if err != nil { return nil, err @@ -45,8 +49,8 @@ func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) { return nil, err } - types.byID[id] = name - types.byName[name] = id + types.ByID[id] = name + types.ByName[name] = id } return types, nil diff --git a/database/pgsql/feature_type_test.go b/database/pgsql/feature/feature_type_test.go similarity index 66% rename from database/pgsql/feature_type_test.go rename to database/pgsql/feature/feature_type_test.go index f8cbf732..bf27f90a 100644 --- a/database/pgsql/feature_type_test.go +++ b/database/pgsql/feature/feature_type_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package feature import ( "testing" @@ -20,19 +20,20 @@ import ( "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) func TestGetFeatureTypeMap(t *testing.T) { - tx, cleanup := createTestPgSession(t, "TestGetFeatureTypeMap") + tx, cleanup := testutil.CreateTestTx(t, "TestGetFeatureTypeMap") defer cleanup() - types, err := tx.getFeatureTypeMap() + types, err := GetFeatureTypeMap(tx) if err != nil { require.Nil(t, err, err.Error()) } - require.Equal(t, database.SourcePackage, types.byID[1]) - require.Equal(t, database.BinaryPackage, types.byID[2]) - require.Equal(t, 1, types.byName[database.SourcePackage]) - require.Equal(t, 2, types.byName[database.BinaryPackage]) + require.Equal(t, database.SourcePackage, types.ByID[1]) + require.Equal(t, database.BinaryPackage, types.ByID[2]) + require.Equal(t, 1, types.ByName[database.SourcePackage]) + require.Equal(t, 2, types.ByName[database.BinaryPackage]) } diff --git a/database/pgsql/feature/namespaced_feature.go b/database/pgsql/feature/namespaced_feature.go new file mode 100644 index 00000000..e3541c7e --- /dev/null +++ b/database/pgsql/feature/namespaced_feature.go @@ -0,0 +1,168 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package feature + +import ( + "database/sql" + "fmt" + "sort" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/util" +) + +var soiNamespacedFeature = ` +WITH new_feature_ns AS ( + INSERT INTO namespaced_feature(feature_id, namespace_id) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER) + WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2) + RETURNING id +) +SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2 +UNION +SELECT id FROM new_feature_ns` + +func queryPersistNamespacedFeature(count int) string { + return util.QueryPersist(count, "namespaced_feature", + "namespaced_feature_namespace_id_feature_id_key", + "feature_id", + "namespace_id") +} + +func querySearchNamespacedFeature(nsfCount int) string { + return fmt.Sprintf(` + SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name + FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t + WHERE nf.feature_id = f.id + AND nf.namespace_id = n.id + AND n.version_format = f.version_format + AND f.type = t.id + AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`, + util.QueryString(5, nsfCount), + ) +} + +type namespacedFeatureWithID struct { + database.NamespacedFeature + + ID int64 +} + +func PersistNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + nsIDs := map[database.Namespace]sql.NullInt64{} + fIDs := map[database.Feature]sql.NullInt64{} + for _, f := range features { + nsIDs[f.Namespace] = sql.NullInt64{} + fIDs[f.Feature] = sql.NullInt64{} + } + + fToFind := []database.Feature{} + for f := range fIDs { + fToFind = append(fToFind, f) + } + + sort.Slice(fToFind, func(i, j int) bool { + return fToFind[i].Name < fToFind[j].Name || + fToFind[i].Version < fToFind[j].Version || + fToFind[i].VersionFormat < fToFind[j].VersionFormat + }) + + if ids, err := FindFeatureIDs(tx, fToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return database.ErrMissingEntities + } + fIDs[fToFind[i]] = id + } + } else { + return err + } + + nsToFind := []database.Namespace{} + for ns := range nsIDs { + nsToFind = append(nsToFind, ns) + } + + if ids, err := namespace.FindNamespaceIDs(tx, nsToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return database.ErrMissingEntities + } + nsIDs[nsToFind[i]] = id + } + } else { + return err + } + + keys := make([]interface{}, 0, len(features)*2) + for _, f := range features { + keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace]) + } + + _, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...) + if err != nil { + return err + } + + return nil +} + +func FindNamespacedFeatureIDs(tx *sql.Tx, nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { + if len(nfs) == 0 { + return nil, nil + } + + nfsMap := map[database.NamespacedFeature]int64{} + keys := make([]interface{}, 0, len(nfs)*5) + for _, nf := range nfs { + keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name) + } + + rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) + if err != nil { + return nil, util.HandleError("searchNamespacedFeature", err) + } + + defer rows.Close() + var ( + id int64 + nf database.NamespacedFeature + ) + + for rows.Next() { + err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name) + nf.Namespace.VersionFormat = nf.VersionFormat + if err != nil { + return nil, util.HandleError("searchNamespacedFeature", err) + } + nfsMap[nf] = id + } + + ids := make([]sql.NullInt64, len(nfs)) + for i, nf := range nfs { + if id, ok := nfsMap[nf]; ok { + ids[i] = sql.NullInt64{id, true} + } else { + ids[i] = sql.NullInt64{} + } + } + + return ids, nil +} diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue/keyvalue.go similarity index 73% rename from database/pgsql/keyvalue.go rename to database/pgsql/keyvalue/keyvalue.go index 9c985279..fc1f5220 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue/keyvalue.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package keyvalue import ( "database/sql" @@ -20,6 +20,8 @@ import ( log "github.com/sirupsen/logrus" + "github.com/coreos/clair/database/pgsql/monitoring" + "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/pkg/commonerr" ) @@ -32,24 +34,24 @@ const ( DO UPDATE SET key=$1, value=$2` ) -func (tx *pgSession) UpdateKeyValue(key, value string) (err error) { +func UpdateKeyValue(tx *sql.Tx, key, value string) (err error) { if key == "" || value == "" { log.Warning("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") } - defer observeQueryTime("PersistKeyValue", "all", time.Now()) + defer monitoring.ObserveQueryTime("PersistKeyValue", "all", time.Now()) _, err = tx.Exec(upsertKeyValue, key, value) if err != nil { - return handleError("insertKeyValue", err) + return util.HandleError("insertKeyValue", err) } return nil } -func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { - defer observeQueryTime("FindKeyValue", "all", time.Now()) +func FindKeyValue(tx *sql.Tx, key string) (string, bool, error) { + defer monitoring.ObserveQueryTime("FindKeyValue", "all", time.Now()) var value string err := tx.QueryRow(searchKeyValue, key).Scan(&value) @@ -59,7 +61,7 @@ func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { } if err != nil { - return "", false, handleError("searchKeyValue", err) + return "", false, util.HandleError("searchKeyValue", err) } return value, true, nil diff --git a/database/pgsql/keyvalue_test.go b/database/pgsql/keyvalue/keyvalue_test.go similarity index 65% rename from database/pgsql/keyvalue_test.go rename to database/pgsql/keyvalue/keyvalue_test.go index 9991bf48..165a75da 100644 --- a/database/pgsql/keyvalue_test.go +++ b/database/pgsql/keyvalue/keyvalue_test.go @@ -12,38 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package keyvalue import ( "testing" + "github.com/coreos/clair/database/pgsql/testutil" "github.com/stretchr/testify/assert" ) func TestKeyValue(t *testing.T) { - datastore, tx := openSessionForTest(t, "KeyValue", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "KeyValue") + defer cleanup() // Get non-existing key/value - f, ok, err := tx.FindKeyValue("test") + f, ok, err := FindKeyValue(tx, "test") assert.Nil(t, err) assert.False(t, ok) // Try to insert invalid key/value. - assert.Error(t, tx.UpdateKeyValue("test", "")) - assert.Error(t, tx.UpdateKeyValue("", "test")) - assert.Error(t, tx.UpdateKeyValue("", "")) + assert.Error(t, UpdateKeyValue(tx, "test", "")) + assert.Error(t, UpdateKeyValue(tx, "", "test")) + assert.Error(t, UpdateKeyValue(tx, "", "")) // Insert and verify. - assert.Nil(t, tx.UpdateKeyValue("test", "test1")) - f, ok, err = tx.FindKeyValue("test") + assert.Nil(t, UpdateKeyValue(tx, "test", "test1")) + f, ok, err = FindKeyValue(tx, "test") assert.Nil(t, err) assert.True(t, ok) assert.Equal(t, "test1", f) // Update and verify. - assert.Nil(t, tx.UpdateKeyValue("test", "test2")) - f, ok, err = tx.FindKeyValue("test") + assert.Nil(t, UpdateKeyValue(tx, "test", "test2")) + f, ok, err = FindKeyValue(tx, "test") assert.Nil(t, err) assert.True(t, ok) assert.Equal(t, "test2", f) diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go deleted file mode 100644 index bce1ef06..00000000 --- a/database/pgsql/layer.go +++ /dev/null @@ -1,379 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "sort" - - "github.com/deckarep/golang-set" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/commonerr" -) - -const ( - soiLayer = ` - WITH new_layer AS ( - INSERT INTO layer (hash) - SELECT CAST ($1 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1) - RETURNING id - ) - SELECT id FROM new_Layer - UNION - SELECT id FROM layer WHERE hash = $1` - - findLayerFeatures = ` - SELECT - f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format - FROM - layer_feature AS lf - LEFT JOIN feature f on f.id = lf.feature_id - LEFT JOIN feature_type ft on ft.id = f.type - LEFT JOIN namespace ns ON ns.id = lf.namespace_id - - WHERE lf.layer_id = $1` - - findLayerNamespaces = ` - SELECT ns.name, ns.version_format, ln.detector_id - FROM layer_namespace AS ln, namespace AS ns - WHERE ln.namespace_id = ns.id - AND ln.layer_id = $1` - - findLayerID = `SELECT id FROM layer WHERE hash = $1` -) - -// dbLayerNamespace represents the layer_namespace table. -type dbLayerNamespace struct { - layerID int64 - namespaceID int64 - detectorID int64 -} - -// dbLayerFeature represents the layer_feature table -type dbLayerFeature struct { - layerID int64 - featureID int64 - detectorID int64 - namespaceID sql.NullInt64 -} - -func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) { - layer := database.Layer{Hash: hash} - if hash == "" { - return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.") - } - - layerID, ok, err := tx.findLayerID(hash) - if err != nil || !ok { - return layer, ok, err - } - - detectorMap, err := tx.findAllDetectors() - if err != nil { - return layer, false, err - } - - if layer.By, err = tx.findLayerDetectors(layerID); err != nil { - return layer, false, err - } - - if layer.Features, err = tx.findLayerFeatures(layerID, detectorMap); err != nil { - return layer, false, err - } - - if layer.Namespaces, err = tx.findLayerNamespaces(layerID, detectorMap); err != nil { - return layer, false, err - } - - return layer, true, nil -} - -func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { - if hash == "" { - return commonerr.NewBadRequestError("expected non-empty layer hash") - } - - detectedBySet := mapset.NewSet() - for _, d := range detectedBy { - detectedBySet.Add(d) - } - - for _, f := range features { - if !detectedBySet.Contains(f.By) { - return database.ErrInvalidParameters - } - } - - for _, n := range namespaces { - if !detectedBySet.Contains(n.By) { - return database.ErrInvalidParameters - } - } - - return nil -} - -// PersistLayer saves the content of a layer to the database. -func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { - var ( - err error - id int64 - detectorIDs []int64 - ) - - if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil { - return err - } - - if id, err = tx.soiLayer(hash); err != nil { - return err - } - - if detectorIDs, err = tx.findDetectorIDs(detectedBy); err != nil { - if err == commonerr.ErrNotFound { - return database.ErrMissingEntities - } - - return err - } - - if err = tx.persistLayerDetectors(id, detectorIDs); err != nil { - return err - } - - if err = tx.persistAllLayerFeatures(id, features); err != nil { - return err - } - - if err = tx.persistAllLayerNamespaces(id, namespaces); err != nil { - return err - } - - return nil -} - -func (tx *pgSession) persistAllLayerNamespaces(layerID int64, namespaces []database.LayerNamespace) error { - detectorMap, err := tx.findAllDetectors() - if err != nil { - return err - } - - // TODO(sidac): This kind of type conversion is very useless and wasteful, - // we need interfaces around the database models to reduce these kind of - // operations. - rawNamespaces := make([]database.Namespace, 0, len(namespaces)) - for _, ns := range namespaces { - rawNamespaces = append(rawNamespaces, ns.Namespace) - } - - rawNamespaceIDs, err := tx.findNamespaceIDs(rawNamespaces) - if err != nil { - return err - } - - dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces)) - for i, ns := range namespaces { - detectorID := detectorMap.byValue[ns.By] - namespaceID := rawNamespaceIDs[i].Int64 - if !rawNamespaceIDs[i].Valid { - return database.ErrMissingEntities - } - - dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID}) - } - - return tx.persistLayerNamespaces(dbLayerNamespaces) -} - -func (tx *pgSession) persistAllLayerFeatures(layerID int64, features []database.LayerFeature) error { - detectorMap, err := tx.findAllDetectors() - if err != nil { - return err - } - var namespaces []database.Namespace - for _, feature := range features { - namespaces = append(namespaces, feature.PotentialNamespace) - } - nameSpaceIDs, _ := tx.findNamespaceIDs(namespaces) - featureNamespaceMap := map[database.Namespace]sql.NullInt64{} - rawFeatures := make([]database.Feature, 0, len(features)) - for i, f := range features { - rawFeatures = append(rawFeatures, f.Feature) - if f.PotentialNamespace.Valid() { - featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i] - } - } - - featureIDs, err := tx.findFeatureIDs(rawFeatures) - if err != nil { - return err - } - var namespaceID sql.NullInt64 - dbFeatures := make([]dbLayerFeature, 0, len(features)) - for i, f := range features { - detectorID := detectorMap.byValue[f.By] - if !featureIDs[i].Valid { - return database.ErrMissingEntities - } - featureID := featureIDs[i].Int64 - namespaceID = featureNamespaceMap[f.PotentialNamespace] - - dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID}) - } - - if err := tx.persistLayerFeatures(dbFeatures); err != nil { - return err - } - - return nil -} - -func (tx *pgSession) persistLayerFeatures(features []dbLayerFeature) error { - if len(features) == 0 { - return nil - } - - sort.Slice(features, func(i, j int) bool { - return features[i].featureID < features[j].featureID - }) - keys := make([]interface{}, 0, len(features)*4) - - for _, f := range features { - keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID) - } - - _, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...) - if err != nil { - return handleError("queryPersistLayerFeature", err) - } - return nil -} - -func (tx *pgSession) persistLayerNamespaces(namespaces []dbLayerNamespace) error { - if len(namespaces) == 0 { - return nil - } - - // for every bulk persist operation, the input data should be sorted. - sort.Slice(namespaces, func(i, j int) bool { - return namespaces[i].namespaceID < namespaces[j].namespaceID - }) - - keys := make([]interface{}, 0, len(namespaces)*3) - for _, row := range namespaces { - keys = append(keys, row.layerID, row.namespaceID, row.detectorID) - } - - _, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) - if err != nil { - return handleError("queryPersistLayerNamespace", err) - } - - return nil -} - -func (tx *pgSession) findLayerNamespaces(layerID int64, detectors detectorMap) ([]database.LayerNamespace, error) { - rows, err := tx.Query(findLayerNamespaces, layerID) - if err != nil { - return nil, handleError("findLayerNamespaces", err) - } - - namespaces := []database.LayerNamespace{} - for rows.Next() { - var ( - namespace database.LayerNamespace - detectorID int64 - ) - - if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil { - return nil, err - } - - namespace.By = detectors.byID[detectorID] - namespaces = append(namespaces, namespace) - } - - return namespaces, nil -} - -func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]database.LayerFeature, error) { - rows, err := tx.Query(findLayerFeatures, layerID) - if err != nil { - return nil, handleError("findLayerFeatures", err) - } - defer rows.Close() - - features := []database.LayerFeature{} - for rows.Next() { - var ( - detectorID int64 - feature database.LayerFeature - ) - var namespaceName, namespaceVersion sql.NullString - if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil { - return nil, handleError("findLayerFeatures", err) - } - feature.PotentialNamespace.Name = namespaceName.String - feature.PotentialNamespace.VersionFormat = namespaceVersion.String - - feature.By = detectors.byID[detectorID] - features = append(features, feature) - } - - return features, nil -} - -func (tx *pgSession) findLayerID(hash string) (int64, bool, error) { - var layerID int64 - err := tx.QueryRow(findLayerID, hash).Scan(&layerID) - if err != nil { - if err == sql.ErrNoRows { - return layerID, false, nil - } - - return layerID, false, handleError("findLayerID", err) - } - - return layerID, true, nil -} - -func (tx *pgSession) findLayerIDs(hashes []string) ([]int64, bool, error) { - layerIDs := make([]int64, 0, len(hashes)) - for _, hash := range hashes { - id, ok, err := tx.findLayerID(hash) - if !ok { - return nil, false, nil - } - - if err != nil { - return nil, false, err - } - - layerIDs = append(layerIDs, id) - } - - return layerIDs, true, nil -} - -func (tx *pgSession) soiLayer(hash string) (int64, error) { - var id int64 - if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil { - return 0, handleError("soiLayer", err) - } - - return id, nil -} diff --git a/database/pgsql/layer/layer.go b/database/pgsql/layer/layer.go new file mode 100644 index 00000000..88fbfcb1 --- /dev/null +++ b/database/pgsql/layer/layer.go @@ -0,0 +1,177 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package layer + +import ( + "database/sql" + + "github.com/deckarep/golang-set" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/pkg/commonerr" +) + +const ( + soiLayer = ` + WITH new_layer AS ( + INSERT INTO layer (hash) + SELECT CAST ($1 AS VARCHAR) + WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1) + RETURNING id + ) + SELECT id FROM new_Layer + UNION + SELECT id FROM layer WHERE hash = $1` + + findLayerID = `SELECT id FROM layer WHERE hash = $1` +) + +func FindLayer(tx *sql.Tx, hash string) (database.Layer, bool, error) { + layer := database.Layer{Hash: hash} + if hash == "" { + return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.") + } + + layerID, ok, err := FindLayerID(tx, hash) + if err != nil || !ok { + return layer, ok, err + } + + detectorMap, err := detector.FindAllDetectors(tx) + if err != nil { + return layer, false, err + } + + if layer.By, err = FindLayerDetectors(tx, layerID); err != nil { + return layer, false, err + } + + if layer.Features, err = FindLayerFeatures(tx, layerID, detectorMap); err != nil { + return layer, false, err + } + + if layer.Namespaces, err = FindLayerNamespaces(tx, layerID, detectorMap); err != nil { + return layer, false, err + } + + return layer, true, nil +} + +func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { + if hash == "" { + return commonerr.NewBadRequestError("expected non-empty layer hash") + } + + detectedBySet := mapset.NewSet() + for _, d := range detectedBy { + detectedBySet.Add(d) + } + + for _, f := range features { + if !detectedBySet.Contains(f.By) { + return database.ErrInvalidParameters + } + } + + for _, n := range namespaces { + if !detectedBySet.Contains(n.By) { + return database.ErrInvalidParameters + } + } + + return nil +} + +// PersistLayer saves the content of a layer to the database. +func PersistLayer(tx *sql.Tx, hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { + var ( + err error + id int64 + detectorIDs []int64 + ) + + if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil { + return err + } + + if id, err = SoiLayer(tx, hash); err != nil { + return err + } + + if detectorIDs, err = detector.FindDetectorIDs(tx, detectedBy); err != nil { + if err == commonerr.ErrNotFound { + return database.ErrMissingEntities + } + + return err + } + + if err = PersistLayerDetectors(tx, id, detectorIDs); err != nil { + return err + } + + if err = PersistAllLayerFeatures(tx, id, features); err != nil { + return err + } + + if err = PersistAllLayerNamespaces(tx, id, namespaces); err != nil { + return err + } + + return nil +} + +func FindLayerID(tx *sql.Tx, hash string) (int64, bool, error) { + var layerID int64 + err := tx.QueryRow(findLayerID, hash).Scan(&layerID) + if err != nil { + if err == sql.ErrNoRows { + return layerID, false, nil + } + + return layerID, false, util.HandleError("findLayerID", err) + } + + return layerID, true, nil +} + +func FindLayerIDs(tx *sql.Tx, hashes []string) ([]int64, bool, error) { + layerIDs := make([]int64, 0, len(hashes)) + for _, hash := range hashes { + id, ok, err := FindLayerID(tx, hash) + if !ok { + return nil, false, nil + } + + if err != nil { + return nil, false, err + } + + layerIDs = append(layerIDs, id) + } + + return layerIDs, true, nil +} + +func SoiLayer(tx *sql.Tx, hash string) (int64, error) { + var id int64 + if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil { + return 0, util.HandleError("soiLayer", err) + } + + return id, nil +} diff --git a/database/pgsql/layer/layer_detector.go b/database/pgsql/layer/layer_detector.go new file mode 100644 index 00000000..78941ca6 --- /dev/null +++ b/database/pgsql/layer/layer_detector.go @@ -0,0 +1,66 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package layer + +import ( + "database/sql" + + "github.com/deckarep/golang-set" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/util" +) + +const ( + selectLayerDetectors = ` + SELECT d.name, d.version, d.dtype + FROM layer_detector, detector AS d + WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;` + + persistLayerDetector = ` + INSERT INTO layer_detector (layer_id, detector_id) + SELECT $1, $2 + WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)` +) + +func PersistLayerDetector(tx *sql.Tx, layerID int64, detectorID int64) error { + if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil { + return util.HandleError("persistLayerDetector", err) + } + + return nil +} + +func PersistLayerDetectors(tx *sql.Tx, layerID int64, detectorIDs []int64) error { + alreadySaved := mapset.NewSet() + for _, id := range detectorIDs { + if alreadySaved.Contains(id) { + continue + } + + alreadySaved.Add(id) + if err := PersistLayerDetector(tx, layerID, id); err != nil { + return err + } + } + + return nil +} + +func FindLayerDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) { + detectors, err := detector.GetDetectors(tx, selectLayerDetectors, id) + return detectors, err +} diff --git a/database/pgsql/layer/layer_feature.go b/database/pgsql/layer/layer_feature.go new file mode 100644 index 00000000..3aa50a42 --- /dev/null +++ b/database/pgsql/layer/layer_feature.go @@ -0,0 +1,147 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package layer + +import ( + "database/sql" + "sort" + + "github.com/coreos/clair/database/pgsql/namespace" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" +) + +const findLayerFeatures = ` +SELECT + f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format +FROM + layer_feature AS lf +LEFT JOIN feature f on f.id = lf.feature_id +LEFT JOIN feature_type ft on ft.id = f.type +LEFT JOIN namespace ns ON ns.id = lf.namespace_id + +WHERE lf.layer_id = $1` + +func queryPersistLayerFeature(count int) string { + return util.QueryPersist(count, + "layer_feature", + "layer_feature_layer_id_feature_id_namespace_id_key", + "layer_id", + "feature_id", + "detector_id", + "namespace_id") +} + +// dbLayerFeature represents the layer_feature table +type dbLayerFeature struct { + layerID int64 + featureID int64 + detectorID int64 + namespaceID sql.NullInt64 +} + +func FindLayerFeatures(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerFeature, error) { + rows, err := tx.Query(findLayerFeatures, layerID) + if err != nil { + return nil, util.HandleError("findLayerFeatures", err) + } + defer rows.Close() + + features := []database.LayerFeature{} + for rows.Next() { + var ( + detectorID int64 + feature database.LayerFeature + ) + var namespaceName, namespaceVersion sql.NullString + if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil { + return nil, util.HandleError("findLayerFeatures", err) + } + feature.PotentialNamespace.Name = namespaceName.String + feature.PotentialNamespace.VersionFormat = namespaceVersion.String + + feature.By = detectors.ByID[detectorID] + features = append(features, feature) + } + + return features, nil +} + +func PersistAllLayerFeatures(tx *sql.Tx, layerID int64, features []database.LayerFeature) error { + detectorMap, err := detector.FindAllDetectors(tx) + if err != nil { + return err + } + var namespaces []database.Namespace + for _, feature := range features { + namespaces = append(namespaces, feature.PotentialNamespace) + } + nameSpaceIDs, _ := namespace.FindNamespaceIDs(tx, namespaces) + featureNamespaceMap := map[database.Namespace]sql.NullInt64{} + rawFeatures := make([]database.Feature, 0, len(features)) + for i, f := range features { + rawFeatures = append(rawFeatures, f.Feature) + if f.PotentialNamespace.Valid() { + featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i] + } + } + + featureIDs, err := feature.FindFeatureIDs(tx, rawFeatures) + if err != nil { + return err + } + var namespaceID sql.NullInt64 + dbFeatures := make([]dbLayerFeature, 0, len(features)) + for i, f := range features { + detectorID := detectorMap.ByValue[f.By] + featureID := featureIDs[i].Int64 + if !featureIDs[i].Valid { + return database.ErrMissingEntities + } + namespaceID = featureNamespaceMap[f.PotentialNamespace] + + dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID}) + } + + if err := PersistLayerFeatures(tx, dbFeatures); err != nil { + return err + } + + return nil +} + +func PersistLayerFeatures(tx *sql.Tx, features []dbLayerFeature) error { + if len(features) == 0 { + return nil + } + + sort.Slice(features, func(i, j int) bool { + return features[i].featureID < features[j].featureID + }) + keys := make([]interface{}, 0, len(features)*4) + + for _, f := range features { + keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID) + } + + _, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...) + if err != nil { + return util.HandleError("queryPersistLayerFeature", err) + } + return nil +} diff --git a/database/pgsql/layer/layer_namespace.go b/database/pgsql/layer/layer_namespace.go new file mode 100644 index 00000000..8e65e126 --- /dev/null +++ b/database/pgsql/layer/layer_namespace.go @@ -0,0 +1,127 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package layer + +import ( + "database/sql" + "sort" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/util" +) + +const findLayerNamespaces = ` +SELECT ns.name, ns.version_format, ln.detector_id + FROM layer_namespace AS ln, namespace AS ns + WHERE ln.namespace_id = ns.id + AND ln.layer_id = $1` + +func queryPersistLayerNamespace(count int) string { + return util.QueryPersist(count, + "layer_namespace", + "layer_namespace_layer_id_namespace_id_key", + "layer_id", + "namespace_id", + "detector_id") +} + +// dbLayerNamespace represents the layer_namespace table. +type dbLayerNamespace struct { + layerID int64 + namespaceID int64 + detectorID int64 +} + +func FindLayerNamespaces(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerNamespace, error) { + rows, err := tx.Query(findLayerNamespaces, layerID) + if err != nil { + return nil, util.HandleError("findLayerNamespaces", err) + } + + namespaces := []database.LayerNamespace{} + for rows.Next() { + var ( + namespace database.LayerNamespace + detectorID int64 + ) + + if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil { + return nil, err + } + + namespace.By = detectors.ByID[detectorID] + namespaces = append(namespaces, namespace) + } + + return namespaces, nil +} + +func PersistAllLayerNamespaces(tx *sql.Tx, layerID int64, namespaces []database.LayerNamespace) error { + detectorMap, err := detector.FindAllDetectors(tx) + if err != nil { + return err + } + + // TODO(sidac): This kind of type conversion is very useless and wasteful, + // we need interfaces around the database models to reduce these kind of + // operations. + rawNamespaces := make([]database.Namespace, 0, len(namespaces)) + for _, ns := range namespaces { + rawNamespaces = append(rawNamespaces, ns.Namespace) + } + + rawNamespaceIDs, err := namespace.FindNamespaceIDs(tx, rawNamespaces) + if err != nil { + return err + } + + dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces)) + for i, ns := range namespaces { + detectorID := detectorMap.ByValue[ns.By] + namespaceID := rawNamespaceIDs[i].Int64 + if !rawNamespaceIDs[i].Valid { + return database.ErrMissingEntities + } + + dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID}) + } + + return PersistLayerNamespaces(tx, dbLayerNamespaces) +} + +func PersistLayerNamespaces(tx *sql.Tx, namespaces []dbLayerNamespace) error { + if len(namespaces) == 0 { + return nil + } + + // for every bulk persist operation, the input data should be sorted. + sort.Slice(namespaces, func(i, j int) bool { + return namespaces[i].namespaceID < namespaces[j].namespaceID + }) + + keys := make([]interface{}, 0, len(namespaces)*3) + for _, row := range namespaces { + keys = append(keys, row.layerID, row.namespaceID, row.detectorID) + } + + _, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) + if err != nil { + return util.HandleError("queryPersistLayerNamespace", err) + } + + return nil +} diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer/layer_test.go similarity index 58% rename from database/pgsql/layer_test.go rename to database/pgsql/layer/layer_test.go index 9b2be2bb..b310c041 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer/layer_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package layer import ( "testing" @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) var persistLayerTests = []struct { @@ -39,9 +40,9 @@ var persistLayerTests = []struct { { title: "layer with inconsistent feature and detectors", name: "random-forest", - by: []database.Detector{realDetectors[2]}, + by: []database.Detector{testutil.RealDetectors[2]}, features: []database.LayerFeature{ - {realFeatures[1], realDetectors[1], database.Namespace{}}, + {testutil.RealFeatures[1], testutil.RealDetectors[1], database.Namespace{}}, }, err: "parameters are not valid", }, @@ -49,70 +50,71 @@ var persistLayerTests = []struct { title: "layer with non-existing feature", name: "random-forest", err: "associated immutable entities are missing in the database", - by: []database.Detector{realDetectors[2]}, + by: []database.Detector{testutil.RealDetectors[2]}, features: []database.LayerFeature{ - {fakeFeatures[1], realDetectors[2], database.Namespace{}}, + {testutil.FakeFeatures[1], testutil.RealDetectors[2], database.Namespace{}}, }, }, { title: "layer with non-existing namespace", name: "random-forest2", err: "associated immutable entities are missing in the database", - by: []database.Detector{realDetectors[1]}, + by: []database.Detector{testutil.RealDetectors[1]}, namespaces: []database.LayerNamespace{ - {fakeNamespaces[1], realDetectors[1]}, + {testutil.FakeNamespaces[1], testutil.RealDetectors[1]}, }, }, { title: "layer with non-existing detector", name: "random-forest3", err: "associated immutable entities are missing in the database", - by: []database.Detector{fakeDetector[1]}, + by: []database.Detector{testutil.FakeDetector[1]}, }, { + title: "valid layer", name: "hamsterhouse", - by: []database.Detector{realDetectors[1], realDetectors[2]}, + by: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]}, features: []database.LayerFeature{ - {realFeatures[1], realDetectors[2], database.Namespace{}}, - {realFeatures[2], realDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}}, }, namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, + {testutil.RealNamespaces[1], testutil.RealDetectors[1]}, }, layer: &database.Layer{ Hash: "hamsterhouse", - By: []database.Detector{realDetectors[1], realDetectors[2]}, + By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]}, Features: []database.LayerFeature{ - {realFeatures[1], realDetectors[2], database.Namespace{}}, - {realFeatures[2], realDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}}, }, Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, + {testutil.RealNamespaces[1], testutil.RealDetectors[1]}, }, }, }, { title: "update existing layer", name: "layer-1", - by: []database.Detector{realDetectors[3], realDetectors[4]}, + by: []database.Detector{testutil.RealDetectors[3], testutil.RealDetectors[4]}, features: []database.LayerFeature{ - {realFeatures[4], realDetectors[3], database.Namespace{}}, + {testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}}, }, namespaces: []database.LayerNamespace{ - {realNamespaces[3], realDetectors[4]}, + {testutil.RealNamespaces[3], testutil.RealDetectors[4]}, }, layer: &database.Layer{ Hash: "layer-1", - By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, + By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2], testutil.RealDetectors[3], testutil.RealDetectors[4]}, Features: []database.LayerFeature{ - {realFeatures[1], realDetectors[2], database.Namespace{}}, - {realFeatures[2], realDetectors[2], database.Namespace{}}, - {realFeatures[4], realDetectors[3], database.Namespace{}}, + {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}}, + {testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}}, }, Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, - {realNamespaces[3], realDetectors[4]}, + {testutil.RealNamespaces[1], testutil.RealDetectors[1]}, + {testutil.RealNamespaces[3], testutil.RealDetectors[4]}, }, }, }, @@ -120,33 +122,33 @@ var persistLayerTests = []struct { { title: "layer with potential namespace", name: "layer-potential-namespace", - by: []database.Detector{realDetectors[3]}, + by: []database.Detector{testutil.RealDetectors[3]}, features: []database.LayerFeature{ - {realFeatures[4], realDetectors[3], realNamespaces[4]}, + {testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]}, }, namespaces: []database.LayerNamespace{ - {realNamespaces[3], realDetectors[3]}, + {testutil.RealNamespaces[3], testutil.RealDetectors[3]}, }, layer: &database.Layer{ Hash: "layer-potential-namespace", - By: []database.Detector{realDetectors[3]}, + By: []database.Detector{testutil.RealDetectors[3]}, Features: []database.LayerFeature{ - {realFeatures[4], realDetectors[3], realNamespaces[4]}, + {testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]}, }, Namespaces: []database.LayerNamespace{ - {realNamespaces[3], realDetectors[3]}, + {testutil.RealNamespaces[3], testutil.RealDetectors[3]}, }, }, }, } func TestPersistLayer(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistLayer", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistLayer") + defer cleanup() for _, test := range persistLayerTests { t.Run(test.title, func(t *testing.T) { - err := tx.PersistLayer(test.name, test.features, test.namespaces, test.by) + err := PersistLayer(tx, test.name, test.features, test.namespaces, test.by) if test.err != "" { assert.EqualError(t, err, test.err, "unexpected error") return @@ -154,7 +156,7 @@ func TestPersistLayer(t *testing.T) { assert.Nil(t, err) if test.layer != nil { - layer, ok, err := tx.FindLayer(test.name) + layer, ok, err := FindLayer(tx, test.name) assert.Nil(t, err) assert.True(t, ok) database.AssertLayerEqual(t, test.layer, &layer) @@ -186,17 +188,17 @@ var findLayerTests = []struct { title: "existing layer", in: "layer-4", ok: true, - out: takeLayerPointerFromMap(realLayers, 6), + out: testutil.TakeLayerPointerFromMap(testutil.RealLayers, 6), }, } func TestFindLayer(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindLayer", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindLayer") + defer cleanup() for _, test := range findLayerTests { t.Run(test.title, func(t *testing.T) { - layer, ok, err := tx.FindLayer(test.in) + layer, ok, err := FindLayer(tx, test.in) if test.err != "" { assert.EqualError(t, err, test.err, "unexpected error") return diff --git a/database/pgsql/lock.go b/database/pgsql/lock/lock.go similarity index 69% rename from database/pgsql/lock.go rename to database/pgsql/lock/lock.go index 0fd73f5b..29a4abe1 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock/lock.go @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package lock import ( + "database/sql" "time" + "github.com/coreos/clair/database/pgsql/monitoring" + "github.com/coreos/clair/database/pgsql/util" log "github.com/sirupsen/logrus" ) @@ -38,12 +41,12 @@ const ( SELECT owner, until FROM lock WHERE name = $1` ) -func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { +func AcquireLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { if lockName == "" || whoami == "" || desiredDuration == 0 { panic("invalid lock parameters") } - if err := tx.pruneLocks(); err != nil { + if err := PruneLocks(tx); err != nil { return false, time.Time{}, err } @@ -54,22 +57,22 @@ func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.D lockOwner string ) - defer observeQueryTime("Lock", "soiLock", time.Now()) + defer monitoring.ObserveQueryTime("Lock", "soiLock", time.Now()) err := tx.QueryRow(soiLock, lockName, whoami, desiredLockedUntil).Scan(&lockOwner, &lockedUntil) - return lockOwner == whoami, lockedUntil, err + return lockOwner == whoami, lockedUntil, util.HandleError("AcquireLock", err) } -func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { +func ExtendLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { if lockName == "" || whoami == "" || desiredDuration == 0 { panic("invalid lock parameters") } desiredLockedUntil := time.Now().Add(desiredDuration) - defer observeQueryTime("Lock", "update", time.Now()) + defer monitoring.ObserveQueryTime("Lock", "update", time.Now()) result, err := tx.Exec(updateLock, lockName, whoami, desiredLockedUntil) if err != nil { - return false, time.Time{}, handleError("updateLock", err) + return false, time.Time{}, util.HandleError("updateLock", err) } if numRows, err := result.RowsAffected(); err == nil { @@ -77,27 +80,27 @@ func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Du return numRows > 0, desiredLockedUntil, nil } - return false, time.Time{}, handleError("updateLock", err) + return false, time.Time{}, util.HandleError("updateLock", err) } -func (tx *pgSession) ReleaseLock(name, owner string) error { +func ReleaseLock(tx *sql.Tx, name, owner string) error { if name == "" || owner == "" { panic("invalid lock parameters") } - defer observeQueryTime("Unlock", "all", time.Now()) + defer monitoring.ObserveQueryTime("Unlock", "all", time.Now()) _, err := tx.Exec(removeLock, name, owner) return err } // pruneLocks removes every expired locks from the database -func (tx *pgSession) pruneLocks() error { - defer observeQueryTime("pruneLocks", "all", time.Now()) +func PruneLocks(tx *sql.Tx) error { + defer monitoring.ObserveQueryTime("pruneLocks", "all", time.Now()) if r, err := tx.Exec(removeLockExpired, time.Now().UTC()); err != nil { - return handleError("removeLockExpired", err) + return util.HandleError("removeLockExpired", err) } else if affected, err := r.RowsAffected(); err != nil { - return handleError("removeLockExpired", err) + return util.HandleError("removeLockExpired", err) } else { log.Debugf("Pruned %d Locks", affected) } diff --git a/database/pgsql/lock/lock_test.go b/database/pgsql/lock/lock_test.go new file mode 100644 index 00000000..356825e3 --- /dev/null +++ b/database/pgsql/lock/lock_test.go @@ -0,0 +1,100 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lock + +import ( + "testing" + "time" + + "github.com/coreos/clair/database/pgsql/testutil" + "github.com/stretchr/testify/require" +) + +func TestAcquireLockReturnsExistingLockDuration(t *testing.T) { + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "Lock") + defer cleanup() + + acquired, originalExpiration, err := AcquireLock(tx, "test1", "owner1", time.Minute) + require.Nil(t, err) + require.True(t, acquired) + + acquired2, expiration, err := AcquireLock(tx, "test1", "owner2", time.Hour) + require.Nil(t, err) + require.False(t, acquired2) + require.Equal(t, expiration, originalExpiration) +} + +func TestLock(t *testing.T) { + db, cleanup := testutil.CreateTestDBWithFixture(t, "Lock") + defer cleanup() + + tx, err := db.Begin() + if err != nil { + panic(err) + } + + // Create a first lock. + l, _, err := AcquireLock(tx, "test1", "owner1", time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, true) + + // lock again by itself, the previous lock is not expired yet. + l, _, err = AcquireLock(tx, "test1", "owner1", time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, false) + + // Try to renew the same lock with another owner. + l, _, err = ExtendLock(tx, "test1", "owner2", time.Minute) + require.Nil(t, err) + require.False(t, l) + tx = testutil.RestartTransaction(db, tx, false) + + l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute) + require.Nil(t, err) + require.False(t, l) + tx = testutil.RestartTransaction(db, tx, false) + + // Renew the lock. + l, _, err = ExtendLock(tx, "test1", "owner1", 2*time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, true) + + // Unlock and then relock by someone else. + err = ReleaseLock(tx, "test1", "owner1") + require.Nil(t, err) + tx = testutil.RestartTransaction(db, tx, true) + + l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, true) + + // Create a second lock which is actually already expired ... + l, _, err = AcquireLock(tx, "test2", "owner1", -time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, true) + + // Take over the lock + l, _, err = AcquireLock(tx, "test2", "owner2", time.Minute) + require.Nil(t, err) + require.True(t, l) + tx = testutil.RestartTransaction(db, tx, true) + + require.Nil(t, tx.Rollback()) +} diff --git a/database/pgsql/lock_test.go b/database/pgsql/lock_test.go deleted file mode 100644 index 538961b6..00000000 --- a/database/pgsql/lock_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2019 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAcquireLockReturnsExistingLockDuration(t *testing.T) { - tx, cleanup := createTestPgSessionWithFixtures(t, "Lock") - defer cleanup() - - acquired, originalExpiration, err := tx.AcquireLock("test1", "owner1", time.Minute) - require.Nil(t, err) - require.True(t, acquired) - - acquired2, expiration, err := tx.AcquireLock("test1", "owner2", time.Hour) - require.Nil(t, err) - require.False(t, acquired2) - require.Equal(t, expiration, originalExpiration) -} - -func TestLock(t *testing.T) { - datastore, tx := openSessionForTest(t, "Lock", true) - defer datastore.Close() - - var l bool - - // Create a first lock. - l, _, err := tx.AcquireLock("test1", "owner1", time.Minute) - assert.Nil(t, err) - assert.True(t, l) - tx = restartSession(t, datastore, tx, true) - - // lock again by itself, the previous lock is not expired yet. - l, _, err = tx.AcquireLock("test1", "owner1", time.Minute) - assert.Nil(t, err) - assert.True(t, l) // acquire lock no-op when owner already has the lock. - tx = restartSession(t, datastore, tx, false) - - // Try to renew the same lock with another owner. - l, _, err = tx.ExtendLock("test1", "owner2", time.Minute) - assert.Nil(t, err) - assert.False(t, l) - tx = restartSession(t, datastore, tx, false) - - l, _, err = tx.AcquireLock("test1", "owner2", time.Minute) - assert.Nil(t, err) - assert.False(t, l) - tx = restartSession(t, datastore, tx, false) - - // Renew the lock. - l, _, err = tx.ExtendLock("test1", "owner1", 2*time.Minute) - assert.Nil(t, err) - assert.True(t, l) - tx = restartSession(t, datastore, tx, true) - - // Unlock and then relock by someone else. - err = tx.ReleaseLock("test1", "owner1") - assert.Nil(t, err) - tx = restartSession(t, datastore, tx, true) - - l, _, err = tx.AcquireLock("test1", "owner2", time.Minute) - assert.Nil(t, err) - assert.True(t, l) - tx = restartSession(t, datastore, tx, true) - - // Create a second lock which is actually already expired ... - l, _, err = tx.AcquireLock("test2", "owner1", -time.Minute) - assert.Nil(t, err) - assert.True(t, l) - tx = restartSession(t, datastore, tx, true) - - // Take over the lock - l, _, err = tx.AcquireLock("test2", "owner2", time.Minute) - assert.Nil(t, err) - assert.True(t, l) - tx = restartSession(t, datastore, tx, true) - - if !assert.Nil(t, tx.Rollback()) { - t.FailNow() - } -} diff --git a/database/pgsql/migrations_test.go b/database/pgsql/migrations/migrations_test.go similarity index 91% rename from database/pgsql/migrations_test.go rename to database/pgsql/migrations/migrations_test.go index e3b2eb30..a324f03f 100644 --- a/database/pgsql/migrations_test.go +++ b/database/pgsql/migrations/migrations_test.go @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package migrations_test import ( "testing" + "github.com/coreos/clair/database/pgsql/migrations" + "github.com/coreos/clair/database/pgsql/testutil" _ "github.com/lib/pq" "github.com/remind101/migrate" "github.com/stretchr/testify/require" - - "github.com/coreos/clair/database/pgsql/migrations" ) var userTableCount = `SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'` func TestMigration(t *testing.T) { - db, cleanup := createAndConnectTestDB(t, "TestMigration") + db, cleanup := testutil.CreateAndConnectTestDB(t, "TestMigration") defer cleanup() err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) diff --git a/database/pgsql/monitoring/prometheus.go b/database/pgsql/monitoring/prometheus.go new file mode 100644 index 00000000..ccd1970b --- /dev/null +++ b/database/pgsql/monitoring/prometheus.go @@ -0,0 +1,67 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package monitoring + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +var ( + PromErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_pgsql_errors_total", + Help: "Number of errors that PostgreSQL requests generated.", + }, []string{"request"}) + + PromCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_pgsql_cache_hits_total", + Help: "Number of cache hits that the PostgreSQL backend did.", + }, []string{"object"}) + + PromCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_pgsql_cache_queries_total", + Help: "Number of cache queries that the PostgreSQL backend did.", + }, []string{"object"}) + + PromQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "clair_pgsql_query_duration_milliseconds", + Help: "Time it takes to execute the database query.", + }, []string{"query", "subquery"}) + + PromConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "clair_pgsql_concurrent_lock_vafv_total", + Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", + }) +) + +func init() { + prometheus.MustRegister(PromErrorsTotal) + prometheus.MustRegister(PromCacheHitsTotal) + prometheus.MustRegister(PromCacheQueriesTotal) + prometheus.MustRegister(PromQueryDurationMilliseconds) + prometheus.MustRegister(PromConcurrentLockVAFV) +} + +// monitoring.ObserveQueryTime computes the time elapsed since `start` to represent the +// query time. +// 1. `query` is a pgSession function name. +// 2. `subquery` is a specific query or a batched query. +// 3. `start` is the time right before query is executed. +func ObserveQueryTime(query, subquery string, start time.Time) { + PromQueryDurationMilliseconds. + WithLabelValues(query, subquery). + Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond)) +} diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace/namespace.go similarity index 73% rename from database/pgsql/namespace.go rename to database/pgsql/namespace/namespace.go index 87d25e33..57ba6ccc 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace/namespace.go @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package namespace import ( "database/sql" + "fmt" "sort" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/pkg/commonerr" ) @@ -26,8 +28,24 @@ const ( searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` ) +func queryPersistNamespace(count int) string { + return util.QueryPersist(count, + "namespace", + "namespace_name_version_format_key", + "name", + "version_format") +} + +func querySearchNamespace(nsCount int) string { + return fmt.Sprintf( + `SELECT id, name, version_format + FROM namespace WHERE (name, version_format) IN (%s)`, + util.QueryString(2, nsCount), + ) +} + // PersistNamespaces soi namespaces into database. -func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { +func PersistNamespaces(tx *sql.Tx, namespaces []database.Namespace) error { if len(namespaces) == 0 { return nil } @@ -49,12 +67,12 @@ func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...) if err != nil { - return handleError("queryPersistNamespace", err) + return util.HandleError("queryPersistNamespace", err) } return nil } -func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) { +func FindNamespaceIDs(tx *sql.Tx, namespaces []database.Namespace) ([]sql.NullInt64, error) { if len(namespaces) == 0 { return nil, nil } @@ -69,7 +87,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...) if err != nil { - return nil, handleError("searchNamespace", err) + return nil, util.HandleError("searchNamespace", err) } defer rows.Close() @@ -81,7 +99,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu for rows.Next() { err := rows.Scan(&id, &ns.Name, &ns.VersionFormat) if err != nil { - return nil, handleError("searchNamespace", err) + return nil, util.HandleError("searchNamespace", err) } nsMap[ns] = id } diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace/namespace_test.go similarity index 67% rename from database/pgsql/namespace_test.go rename to database/pgsql/namespace/namespace_test.go index 8f2af288..2cc9918e 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace/namespace_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package namespace import ( "testing" @@ -20,25 +20,26 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/testutil" ) func TestPersistNamespaces(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistNamespaces", false) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTx(t, "PersistNamespaces") + defer cleanup() ns1 := database.Namespace{} ns2 := database.Namespace{Name: "t", VersionFormat: "b"} // Empty Case - assert.Nil(t, tx.PersistNamespaces([]database.Namespace{})) + assert.Nil(t, PersistNamespaces(tx, []database.Namespace{})) // Invalid Case - assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1})) + assert.NotNil(t, PersistNamespaces(tx, []database.Namespace{ns1})) // Duplicated Case - assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2})) + assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2, ns2})) // Existing Case - assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2})) + assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2})) - nsList := listNamespaces(t, tx) + nsList := testutil.ListNamespaces(t, tx) assert.Len(t, nsList, 1) assert.Equal(t, ns2, nsList[0]) } diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification/notification_test.go similarity index 61% rename from database/pgsql/notification_test.go rename to database/pgsql/notification/notification_test.go index da3b3248..bc9d2acc 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification/notification_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package notification import ( "testing" @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/database/pgsql/testutil" "github.com/coreos/clair/pkg/pagination" ) @@ -38,6 +40,8 @@ type findVulnerabilityNotificationOut struct { err string } +var testPaginationKey = pagination.Must(pagination.NewKey()) + var findVulnerabilityNotificationTests = []struct { title string in findVulnerabilityNotificationIn @@ -77,21 +81,21 @@ var findVulnerabilityNotificationTests = []struct { }, out: findVulnerabilityNotificationOut{ &database.VulnerabilityNotificationWithVulnerable{ - NotificationHook: realNotification[1].NotificationHook, + NotificationHook: testutil.RealNotification[1].NotificationHook, Old: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[2], + Vulnerability: testutil.RealVulnerability[2], Limit: 1, Affected: make(map[int]string), - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, New: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[1], + Vulnerability: testutil.RealVulnerability[1], Limit: 1, Affected: map[int]string{3: "ancestry-3"}, - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{4}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), End: false, }, }, @@ -100,32 +104,31 @@ var findVulnerabilityNotificationTests = []struct { "", }, }, - { title: "find existing notification of second page of new affected ancestry", in: findVulnerabilityNotificationIn{ notificationName: "test", pageSize: 1, oldAffectedAncestryPage: pagination.FirstPageToken, - newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}), + newAffectedAncestryPage: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), }, out: findVulnerabilityNotificationOut{ &database.VulnerabilityNotificationWithVulnerable{ - NotificationHook: realNotification[1].NotificationHook, + NotificationHook: testutil.RealNotification[1].NotificationHook, Old: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[2], + Vulnerability: testutil.RealVulnerability[2], Limit: 1, Affected: make(map[int]string), - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, New: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[1], + Vulnerability: testutil.RealVulnerability[1], Limit: 1, Affected: map[int]string{4: "ancestry-4"}, - Current: mustMarshalToken(testPaginationKey, Page{4}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, }, @@ -137,12 +140,12 @@ var findVulnerabilityNotificationTests = []struct { } func TestFindVulnerabilityNotification(t *testing.T) { - datastore, tx := openSessionForTest(t, "pagination", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "pagination") + defer cleanup() for _, test := range findVulnerabilityNotificationTests { t.Run(test.title, func(t *testing.T) { - notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage) + notification, ok, err := FindVulnerabilityNotification(tx, test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage, testutil.TestPaginationKey) if test.out.err != "" { require.EqualError(t, err, test.out.err) return @@ -155,13 +158,14 @@ func TestFindVulnerabilityNotification(t *testing.T) { } require.True(t, ok) - assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, ¬ification) + testutil.AssertVulnerabilityNotificationWithVulnerableEqual(t, testutil.TestPaginationKey, test.out.notification, ¬ification) }) } } func TestInsertVulnerabilityNotifications(t *testing.T) { - datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true) + datastore, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilityNotifications") + defer cleanup() n1 := database.VulnerabilityNotification{} n3 := database.VulnerabilityNotification{ @@ -187,34 +191,37 @@ func TestInsertVulnerabilityNotifications(t *testing.T) { }, } + tx, err := datastore.Begin() + require.Nil(t, err) + // invalid case - err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n1}) + require.NotNil(t, err) // invalid case: unknown vulnerability - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n3}) + require.NotNil(t, err) // invalid case: duplicated input notification - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) - assert.NotNil(t, err) - tx = restartSession(t, datastore, tx, false) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4, n4}) + require.NotNil(t, err) + tx = testutil.RestartTransaction(datastore, tx, false) // valid case - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) - assert.Nil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4}) + require.Nil(t, err) // invalid case: notification is already in database - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4}) + require.NotNil(t, err) - closeTest(t, datastore, tx) + require.Nil(t, tx.Rollback()) } func TestFindNewNotification(t *testing.T) { - tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification") + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNewNotification") defer cleanup() - noti, ok, err := tx.FindNewNotification(time.Now()) + noti, ok, err := FindNewNotification(tx, time.Now()) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) assert.Equal(t, time.Time{}, noti.Notified) @@ -223,13 +230,13 @@ func TestFindNewNotification(t *testing.T) { } // can't find the notified - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) // if the notified time is before - noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(10*time.Second))) assert.Nil(t, err) assert.False(t, ok) // can find the notified after a period of time - noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(10*time.Second))) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) assert.NotEqual(t, time.Time{}, noti.Notified) @@ -237,37 +244,37 @@ func TestFindNewNotification(t *testing.T) { assert.Equal(t, time.Time{}, noti.Deleted) } - assert.Nil(t, tx.DeleteNotification("test")) + assert.Nil(t, DeleteNotification(tx, "test")) // can't find in any time - noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(1000))) assert.Nil(t, err) assert.False(t, ok) - noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(1000))) assert.Nil(t, err) assert.False(t, ok) } func TestMarkNotificationAsRead(t *testing.T) { - datastore, tx := openSessionForTest(t, "MarkNotificationAsRead", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "MarkNotificationAsRead") + defer cleanup() // invalid case: notification doesn't exist - assert.NotNil(t, tx.MarkNotificationAsRead("non-existing")) + assert.NotNil(t, MarkNotificationAsRead(tx, "non-existing")) // valid case - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) // valid case - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) } func TestDeleteNotification(t *testing.T) { - datastore, tx := openSessionForTest(t, "DeleteNotification", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteNotification") + defer cleanup() // invalid case: notification doesn't exist - assert.NotNil(t, tx.DeleteNotification("non-existing")) + assert.NotNil(t, DeleteNotification(tx, "non-existing")) // valid case - assert.Nil(t, tx.DeleteNotification("test")) + assert.Nil(t, DeleteNotification(tx, "test")) // invalid case: notification is already deleted - assert.NotNil(t, tx.DeleteNotification("test")) + assert.NotNil(t, DeleteNotification(tx, "test")) } diff --git a/database/pgsql/notification.go b/database/pgsql/notification/vulnerability_notification.go similarity index 60% rename from database/pgsql/notification.go rename to database/pgsql/notification/vulnerability_notification.go index 7d2b750d..e0ac3a1c 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification/vulnerability_notification.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package notification import ( "database/sql" @@ -22,6 +22,8 @@ import ( "github.com/guregu/null/zero" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/database/pgsql/vulnerability" "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/pagination" ) @@ -54,26 +56,24 @@ const ( SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` - - searchNotificationVulnerableAncestry = ` - SELECT DISTINCT ON (a.id) - a.id, a.name - FROM vulnerability_affected_namespaced_feature AS vanf, - ancestry_layer AS al, ancestry_feature AS af, ancestry AS a - WHERE vanf.vulnerability_id = $1 - AND a.id >= $2 - AND al.ancestry_id = a.id - AND al.id = af.ancestry_layer_id - AND af.namespaced_feature_id = vanf.namespaced_feature_id - ORDER BY a.id ASC - LIMIT $3;` ) +func queryInsertNotifications(count int) string { + return util.QueryInsert(count, + "vulnerability_notification", + "name", + "created_at", + "old_vulnerability_id", + "new_vulnerability_id", + ) +} + var ( - errNotificationNotFound = errors.New("requested notification is not found") + errNotificationNotFound = errors.New("requested notification is not found") + errVulnerabilityNotFound = errors.New("vulnerability is not in database") ) -func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { +func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error { if len(notifications) == 0 { return nil } @@ -122,26 +122,26 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V oldVulnIDs = append(oldVulnIDs, vulnID) } - ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs) + ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { - return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) + return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) } newVulnIDMap[newVulnIDs[i]] = id } - ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs) + ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { - return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) + return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) } oldVulnIDMap[oldVulnIDs[i]] = id } @@ -178,13 +178,13 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V // multiple updaters, deadlock may happen. _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) if err != nil { - return handleError("queryInsertNotifications", err) + return util.HandleError("queryInsertNotifications", err) } return nil } -func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) { +func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) { var ( notification database.NotificationHook created zero.Time @@ -197,7 +197,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not if err == sql.ErrNoRows { return notification, false, nil } - return notification, false, handleError("searchNotificationAvailable", err) + return notification, false, util.HandleError("searchNotificationAvailable", err) } notification.Created = created.Time @@ -207,71 +207,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not return notification, true, nil } -func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { - vulnPage := database.PagedVulnerableAncestries{Limit: limit} - currentPage := Page{0} - if currentToken != pagination.FirstPageToken { - if err := tx.key.UnmarshalToken(currentToken, ¤tPage); err != nil { - return vulnPage, err - } - } - - if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( - &vulnPage.Name, - &vulnPage.Description, - &vulnPage.Link, - &vulnPage.Severity, - &vulnPage.Metadata, - &vulnPage.Namespace.Name, - &vulnPage.Namespace.VersionFormat, - ); err != nil { - return vulnPage, handleError("searchVulnerabilityByID", err) - } - - // the last result is used for the next page's startID - rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) - if err != nil { - return vulnPage, handleError("searchNotificationVulnerableAncestry", err) - } - defer rows.Close() - - ancestries := []affectedAncestry{} - for rows.Next() { - var ancestry affectedAncestry - err := rows.Scan(&ancestry.id, &ancestry.name) - if err != nil { - return vulnPage, handleError("searchNotificationVulnerableAncestry", err) - } - ancestries = append(ancestries, ancestry) - } - - lastIndex := 0 - if len(ancestries)-1 < limit { - lastIndex = len(ancestries) - vulnPage.End = true - } else { - // Use the last ancestry's ID as the next page. - lastIndex = len(ancestries) - 1 - vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id}) - if err != nil { - return vulnPage, err - } - } - - vulnPage.Affected = map[int]string{} - for _, ancestry := range ancestries[0:lastIndex] { - vulnPage.Affected[int(ancestry.id)] = ancestry.name - } - - vulnPage.Current, err = tx.key.MarshalToken(currentPage) - if err != nil { - return vulnPage, err - } - - return vulnPage, nil -} - -func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) ( +func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable @@ -294,7 +230,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa if err == sql.ErrNoRows { return noti, false, nil } - return noti, false, handleError("searchNotification", err) + return noti, false, util.HandleError("searchNotification", err) } if created.Valid { @@ -310,7 +246,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if oldVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) + page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key) if err != nil { return noti, false, err } @@ -318,7 +254,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if newVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) + page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key) if err != nil { return noti, false, err } @@ -328,44 +264,44 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa return noti, true, nil } -func (tx *pgSession) MarkNotificationAsRead(name string) error { +func MarkNotificationAsRead(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } r, err := tx.Exec(updatedNotificationAsRead, name) if err != nil { - return handleError("updatedNotificationAsRead", err) + return util.HandleError("updatedNotificationAsRead", err) } affected, err := r.RowsAffected() if err != nil { - return handleError("updatedNotificationAsRead", err) + return util.HandleError("updatedNotificationAsRead", err) } if affected <= 0 { - return handleError("updatedNotificationAsRead", errNotificationNotFound) + return util.HandleError("updatedNotificationAsRead", errNotificationNotFound) } return nil } -func (tx *pgSession) DeleteNotification(name string) error { +func DeleteNotification(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } result, err := tx.Exec(removeNotification, name) if err != nil { - return handleError("removeNotification", err) + return util.HandleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("removeNotification", err) + return util.HandleError("removeNotification", err) } if affected <= 0 { - return handleError("removeNotification", commonerr.ErrNotFound) + return util.HandleError("removeNotification", commonerr.ErrNotFound) } return nil diff --git a/database/pgsql/page/page.go b/database/pgsql/page/page.go new file mode 100644 index 00000000..3ec17e8e --- /dev/null +++ b/database/pgsql/page/page.go @@ -0,0 +1,24 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package page + +// Page is the representation of a page for the Postgres schema. +type Page struct { + // StartID is the ID being used as the basis for pagination across database + // results. It is used to search for an ancestry with ID >= StartID. + // + // StartID is required to be unique to every ancestry and always increasing. + StartID int64 +} diff --git a/database/pgsql/pgsession.go b/database/pgsql/pgsession.go new file mode 100644 index 00000000..646a768c --- /dev/null +++ b/database/pgsql/pgsession.go @@ -0,0 +1,134 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "database/sql" + "time" + + "github.com/coreos/clair/database/pgsql/keyvalue" + "github.com/coreos/clair/database/pgsql/vulnerability" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/ancestry" + "github.com/coreos/clair/database/pgsql/detector" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/layer" + "github.com/coreos/clair/database/pgsql/lock" + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/notification" + "github.com/coreos/clair/pkg/pagination" +) + +// Enforce the interface at compile time. +var _ database.Session = &pgSession{} + +type pgSession struct { + *sql.Tx + + key pagination.Key +} + +func (tx *pgSession) UpsertAncestry(a database.Ancestry) error { + return ancestry.UpsertAncestry(tx.Tx, a) +} + +func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) { + return ancestry.FindAncestry(tx.Tx, name) +} + +func (tx *pgSession) PersistDetectors(detectors []database.Detector) error { + return detector.PersistDetectors(tx.Tx, detectors) +} + +func (tx *pgSession) PersistFeatures(features []database.Feature) error { + return feature.PersistFeatures(tx.Tx, features) +} + +func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error { + return feature.PersistNamespacedFeatures(tx.Tx, features) +} + +func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error { + return vulnerability.CacheAffectedNamespacedFeatures(tx.Tx, features) +} + +func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + return vulnerability.FindAffectedNamespacedFeatures(tx.Tx, features) +} + +func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { + return namespace.PersistNamespaces(tx.Tx, namespaces) +} + +func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { + return layer.PersistLayer(tx.Tx, hash, features, namespaces, detectedBy) +} + +func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) { + return layer.FindLayer(tx.Tx, hash) +} + +func (tx *pgSession) InsertVulnerabilities(vulns []database.VulnerabilityWithAffected) error { + return vulnerability.InsertVulnerabilities(tx.Tx, vulns) +} + +func (tx *pgSession) FindVulnerabilities(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + return vulnerability.FindVulnerabilities(tx.Tx, ids) +} + +func (tx *pgSession) DeleteVulnerabilities(ids []database.VulnerabilityID) error { + return vulnerability.DeleteVulnerabilities(tx.Tx, ids) +} + +func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { + return notification.InsertVulnerabilityNotifications(tx.Tx, notifications) +} + +func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (hook database.NotificationHook, found bool, err error) { + return notification.FindNewNotification(tx.Tx, notifiedBefore) +} + +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti database.VulnerabilityNotificationWithVulnerable, found bool, err error) { + return notification.FindVulnerabilityNotification(tx.Tx, name, limit, oldVulnerabilityPage, newVulnerabilityPage, tx.key) +} + +func (tx *pgSession) MarkNotificationAsRead(name string) error { + return notification.MarkNotificationAsRead(tx.Tx, name) +} + +func (tx *pgSession) DeleteNotification(name string) error { + return notification.DeleteNotification(tx.Tx, name) +} + +func (tx *pgSession) UpdateKeyValue(key, value string) error { + return keyvalue.UpdateKeyValue(tx.Tx, key, value) +} + +func (tx *pgSession) FindKeyValue(key string) (value string, found bool, err error) { + return keyvalue.FindKeyValue(tx.Tx, key) +} + +func (tx *pgSession) AcquireLock(name, owner string, duration time.Duration) (acquired bool, expiration time.Time, err error) { + return lock.AcquireLock(tx.Tx, name, owner, duration) +} + +func (tx *pgSession) ExtendLock(name, owner string, duration time.Duration) (extended bool, expiration time.Time, err error) { + return lock.ExtendLock(tx.Tx, name, owner, duration) +} + +func (tx *pgSession) ReleaseLock(name, owner string) error { + return lock.ReleaseLock(tx.Tx, name, owner) +} diff --git a/database/pgsql/pgsession_test.go b/database/pgsql/pgsession_test.go new file mode 100644 index 00000000..8f991882 --- /dev/null +++ b/database/pgsql/pgsession_test.go @@ -0,0 +1,56 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/testutil" +) + +const ( + numVulnerabilities = 100 + numFeatures = 100 +) + +func TestConcurrency(t *testing.T) { + db, cleanup := testutil.CreateTestDB(t, "concurrency") + defer cleanup() + + var wg sync.WaitGroup + // there's a limit on the number of concurrent connections in the pool + wg.Add(30) + for i := 0; i < 30; i++ { + go func() { + defer wg.Done() + nsNamespaces := testutil.GenRandomNamespaces(t, 100) + tx, err := db.Begin() + if err != nil { + panic(err) + } + + assert.Nil(t, namespace.PersistNamespaces(tx, nsNamespaces)) + if err := tx.Commit(); err != nil { + panic(err) + } + }() + } + + wg.Wait() +} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 4b23e014..0d0919f2 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -21,13 +21,10 @@ import ( "io/ioutil" "net/url" "strings" - "time" "gopkg.in/yaml.v2" "github.com/hashicorp/golang-lru" - "github.com/lib/pq" - "github.com/prometheus/client_golang/prometheus" "github.com/remind101/migrate" log "github.com/sirupsen/logrus" @@ -37,50 +34,10 @@ import ( "github.com/coreos/clair/pkg/pagination" ) -var ( - promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_errors_total", - Help: "Number of errors that PostgreSQL requests generated.", - }, []string{"request"}) - - promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_cache_hits_total", - Help: "Number of cache hits that the PostgreSQL backend did.", - }, []string{"object"}) - - promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_cache_queries_total", - Help: "Number of cache queries that the PostgreSQL backend did.", - }, []string{"object"}) - - promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "clair_pgsql_query_duration_milliseconds", - Help: "Time it takes to execute the database query.", - }, []string{"query", "subquery"}) - - promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "clair_pgsql_concurrent_lock_vafv_total", - Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", - }) -) - func init() { - prometheus.MustRegister(promErrorsTotal) - prometheus.MustRegister(promCacheHitsTotal) - prometheus.MustRegister(promCacheQueriesTotal) - prometheus.MustRegister(promQueryDurationMilliseconds) - prometheus.MustRegister(promConcurrentLockVAFV) - database.Register("pgsql", openDatabase) } -// pgSessionCache is the session's cache, which holds the pgSQL's cache and the -// individual session's cache. Only when session.Commit is called, all the -// changes to pgSQL cache will be applied. -type pgSessionCache struct { - c *lru.ARCCache -} - type pgSQL struct { *sql.DB @@ -88,12 +45,6 @@ type pgSQL struct { config Config } -type pgSession struct { - *sql.Tx - - key pagination.Key -} - // Begin initiates a transaction to database. // // The expected transaction isolation level in this implementation is "Read @@ -109,10 +60,6 @@ func (pgSQL *pgSQL) Begin() (database.Session, error) { }, nil } -func (tx *pgSession) Commit() error { - return tx.Tx.Commit() -} - // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // the configuration. func (pgSQL *pgSQL) Close() { @@ -131,15 +78,6 @@ func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } -// Page is the representation of a page for the Postgres schema. -type Page struct { - // StartID is the ID being used as the basis for pagination across database - // results. It is used to search for an ancestry with ID >= StartID. - // - // StartID is required to be unique to every ancestry and always increasing. - StartID int64 -} - // Config is the configuration that is used by openDatabase. type Config struct { Source string @@ -313,42 +251,3 @@ func dropDatabase(source, dbName string) error { return nil } - -// handleError logs an error with an extra description and masks the error if it's an SQL one. -// The function ensures we never return plain SQL errors and leak anything. -// The function should be used for every database query error. -func handleError(desc string, err error) error { - if err == nil { - return nil - } - - if err == sql.ErrNoRows { - return commonerr.ErrNotFound - } - - log.WithError(err).WithField("Description", desc).Error("database: handled database error") - promErrorsTotal.WithLabelValues(desc).Inc() - - if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { - return database.ErrBackendException - } - - return err -} - -// isErrUniqueViolation determines is the given error is a unique contraint violation. -func isErrUniqueViolation(err error) bool { - pqErr, ok := err.(*pq.Error) - return ok && pqErr.Code == "23505" -} - -// observeQueryTime computes the time elapsed since `start` to represent the -// query time. -// 1. `query` is a pgSession function name. -// 2. `subquery` is a specific query or a batched query. -// 3. `start` is the time right before query is executed. -func observeQueryTime(query, subquery string, start time.Time) { - promQueryDurationMilliseconds. - WithLabelValues(query, subquery). - Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond)) -} diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go deleted file mode 100644 index b79f0d98..00000000 --- a/database/pgsql/pgsql_test.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - - "github.com/pborman/uuid" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - yaml "gopkg.in/yaml.v2" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/pagination" -) - -var ( - withFixtureName, withoutFixtureName string -) - -var testPaginationKey = pagination.Must(pagination.NewKey()) - -func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) { - config := generateTestConfig(name, loadFixture, false) - source := config.Options["source"].(string) - name, url, err := parseConnectionString(source) - if err != nil { - panic(err) - } - - fixturePath := config.Options["fixturepath"].(string) - - if err := createDatabase(url, name); err != nil { - panic(err) - } - - // migration and fixture - db, err := sql.Open("postgres", source) - if err != nil { - panic(err) - } - - // Verify database state. - if err := db.Ping(); err != nil { - panic(err) - } - - // Run migrations. - if err := migrateDatabase(db); err != nil { - panic(err) - } - - if loadFixture { - d, err := ioutil.ReadFile(fixturePath) - if err != nil { - panic(err) - } - - _, err = db.Exec(string(d)) - if err != nil { - panic(err) - } - } - - db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name) - db.Close() - - log.Info("Generated Template database ", name) - return url, name -} - -func dropTemplateDatabase(url string, name string) { - db, err := sql.Open("postgres", url) - if err != nil { - panic(err) - } - - if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil { - panic(err) - } - - if err := db.Close(); err != nil { - panic(err) - } - - if err := dropDatabase(url, name); err != nil { - panic(err) - } - -} - -func TestMain(m *testing.M) { - fURL, fName := genTemplateDatabase("fixture", true) - nfURL, nfName := genTemplateDatabase("nonfixture", false) - - withFixtureName = fName - withoutFixtureName = nfName - - rc := m.Run() - - dropTemplateDatabase(fURL, fName) - dropTemplateDatabase(nfURL, nfName) - os.Exit(rc) -} - -func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) { - var fixtureName string - if fixture { - fixtureName = withFixtureName - } else { - fixtureName = withoutFixtureName - } - - // copy the database into new database - var pg pgSQL - // Parse configuration. - pg.config = Config{ - CacheSize: 16384, - } - - bytes, err := yaml.Marshal(testConfig.Options) - if err != nil { - return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) - } - err = yaml.Unmarshal(bytes, &pg.config) - if err != nil { - return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) - } - - dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) - if err != nil { - return nil, err - } - - // Create database. - if pg.config.ManageDatabaseLifecycle { - if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil { - return nil, err - } - } - - // Open database. - pg.DB, err = sql.Open("postgres", pg.config.Source) - fmt.Println("database", pg.config.Source) - if err != nil { - pg.Close() - return nil, fmt.Errorf("pgsql: could not open database: %v", err) - } - - return &pg, nil -} - -// copyDatabase creates a new database with -func copyDatabase(url, name string, templateName string) error { - // Open database. - db, err := sql.Open("postgres", url) - if err != nil { - return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) - } - defer db.Close() - - // Create database with copy - _, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName) - if err != nil { - return fmt.Errorf("pgsql: could not create database: %v", err) - } - - return nil -} - -func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { - var ( - db database.Datastore - err error - testConfig = generateTestConfig(testName, loadFixture, true) - ) - - db, err = openCopiedDatabase(testConfig, loadFixture) - - if err != nil { - return nil, err - } - datastore := db.(*pgSQL) - return datastore, nil -} - -func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig { - dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) - - var fixturePath string - if loadFixture { - _, filename, _, _ := runtime.Caller(0) - fixturePath = filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql" - } - - source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName) - if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" { - source = fmt.Sprintf(sourceEnv, dbName) - } - - return database.RegistrableComponentConfig{ - Options: map[string]interface{}{ - "source": source, - "cachesize": 0, - "managedatabaselifecycle": manageLife, - "fixturepath": fixturePath, - "paginationkey": testPaginationKey.String(), - }, - } -} - -func closeTest(t *testing.T, store database.Datastore, session database.Session) { - err := session.Rollback() - if err != nil { - t.Error(err) - t.FailNow() - } - - store.Close() -} - -func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) { - store, err := openDatabaseForTest(name, loadFixture) - if err != nil { - t.Error(err) - t.FailNow() - } - tx, err := store.Begin() - if err != nil { - t.Error(err) - t.FailNow() - } - - return store, tx.(*pgSession) -} - -func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession { - var err error - if !commit { - err = tx.Rollback() - } else { - err = tx.Commit() - } - - if assert.Nil(t, err) { - session, err := datastore.Begin() - if assert.Nil(t, err) { - return session.(*pgSession) - } - } - t.FailNow() - return nil -} diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go deleted file mode 100644 index e19f466a..00000000 --- a/database/pgsql/queries.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2015 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "fmt" - "strings" - - "github.com/lib/pq" -) - -// NOTE(Sida): Every search query can only have count less than postgres set -// stack depth. IN will be resolved to nested OR_s and the parser might exceed -// stack depth. TODO(Sida): Generate different queries for different count: if -// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for -// count > 65535, use is expected to split data into batches. -func querySearchLastDeletedVulnerabilityID(count int) string { - return fmt.Sprintf(` - SELECT vid, vname, nname FROM ( - SELECT v.id AS vid, v.name AS vname, n.name AS nname, - row_number() OVER ( - PARTITION by (v.name, n.name) - ORDER BY v.deleted_at DESC - ) AS rownum - FROM vulnerability AS v, namespace AS n - WHERE v.namespace_id = n.id - AND (v.name, n.name) IN ( %s ) - AND v.deleted_at IS NOT NULL - ) tmp WHERE rownum <= 1`, - queryString(2, count)) -} - -func querySearchNotDeletedVulnerabilityID(count int) string { - return fmt.Sprintf(` - SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n - WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) - AND v.deleted_at IS NULL`, - queryString(2, count)) -} - -func querySearchFeatureID(featureCount int) string { - return fmt.Sprintf(` - SELECT id, name, version, version_format, type - FROM Feature WHERE (name, version, version_format, type) IN (%s)`, - queryString(4, featureCount), - ) -} - -func querySearchNamespacedFeature(nsfCount int) string { - return fmt.Sprintf(` - SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name - FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t - WHERE nf.feature_id = f.id - AND nf.namespace_id = n.id - AND n.version_format = f.version_format - AND f.type = t.id - AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`, - queryString(5, nsfCount), - ) -} - -func querySearchNamespace(nsCount int) string { - return fmt.Sprintf( - `SELECT id, name, version_format - FROM namespace WHERE (name, version_format) IN (%s)`, - queryString(2, nsCount), - ) -} - -func queryInsert(count int, table string, columns ...string) string { - base := `INSERT INTO %s (%s) VALUES %s` - t := pq.QuoteIdentifier(table) - cols := make([]string, len(columns)) - for i, c := range columns { - cols[i] = pq.QuoteIdentifier(c) - } - colsQuoted := strings.Join(cols, ",") - return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count)) -} - -func queryPersist(count int, table, constraint string, columns ...string) string { - ct := "" - if constraint != "" { - ct = fmt.Sprintf("ON CONSTRAINT %s", constraint) - } - return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct) -} - -func queryInsertNotifications(count int) string { - return queryInsert(count, - "vulnerability_notification", - "name", - "created_at", - "old_vulnerability_id", - "new_vulnerability_id", - ) -} - -func queryPersistFeature(count int) string { - return queryPersist(count, - "feature", - "feature_name_version_version_format_type_key", - "name", - "version", - "version_format", - "type") -} - -func queryPersistLayerFeature(count int) string { - return queryPersist(count, - "layer_feature", - "layer_feature_layer_id_feature_id_namespace_id_key", - "layer_id", - "feature_id", - "detector_id", - "namespace_id") -} - -func queryPersistNamespace(count int) string { - return queryPersist(count, - "namespace", - "namespace_name_version_format_key", - "name", - "version_format") -} - -func queryPersistLayerNamespace(count int) string { - return queryPersist(count, - "layer_namespace", - "layer_namespace_layer_id_namespace_id_key", - "layer_id", - "namespace_id", - "detector_id") -} - -// size of key and array should be both greater than 0 -func queryString(keySize, arraySize int) string { - if arraySize <= 0 || keySize <= 0 { - panic("Bulk Query requires size of element tuple and number of elements to be greater than 0") - } - keys := make([]string, 0, arraySize) - for i := 0; i < arraySize; i++ { - key := make([]string, keySize) - for j := 0; j < keySize; j++ { - key[j] = fmt.Sprintf("$%d", i*keySize+j+1) - } - keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ","))) - } - return strings.Join(keys, ",") -} - -func queryPersistNamespacedFeature(count int) string { - return queryPersist(count, "namespaced_feature", - "namespaced_feature_namespace_id_feature_id_key", - "feature_id", - "namespace_id") -} - -func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string { - return queryPersist(count, "vulnerability_affected_namespaced_feature", - "vulnerability_affected_namesp_vulnerability_id_namespaced_f_key", - "vulnerability_id", - "namespaced_feature_id", - "added_by") -} - -func queryPersistLayer(count int) string { - return queryPersist(count, "layer", "", "hash") -} - -func queryInvalidateVulnerabilityCache(count int) string { - return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature - WHERE vulnerability_id IN (%s)`, - queryString(1, count)) -} diff --git a/database/pgsql/testutil.go b/database/pgsql/testutil.go deleted file mode 100644 index b1bfabf7..00000000 --- a/database/pgsql/testutil.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "fmt" - "io/ioutil" - "math/rand" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/remind101/migrate" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/database/pgsql/migrations" - "github.com/coreos/clair/pkg/pagination" -) - -// int keys must be the consistent with the database ID. -var ( - realFeatures = map[int]database.Feature{ - 1: {"ourchat", "0.5", "dpkg", "source"}, - 2: {"openssl", "1.0", "dpkg", "source"}, - 3: {"openssl", "2.0", "dpkg", "source"}, - 4: {"fake", "2.0", "rpm", "source"}, - 5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"}, - } - - realNamespaces = map[int]database.Namespace{ - 1: {"debian:7", "dpkg"}, - 2: {"debian:8", "dpkg"}, - 3: {"fake:1.0", "rpm"}, - 4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"}, - } - - realNamespacedFeatures = map[int]database.NamespacedFeature{ - 1: {realFeatures[1], realNamespaces[1]}, - 2: {realFeatures[2], realNamespaces[1]}, - 3: {realFeatures[2], realNamespaces[2]}, - 4: {realFeatures[3], realNamespaces[1]}, - } - - realDetectors = map[int]database.Detector{ - 1: database.NewNamespaceDetector("os-release", "1.0"), - 2: database.NewFeatureDetector("dpkg", "1.0"), - 3: database.NewFeatureDetector("rpm", "1.0"), - 4: database.NewNamespaceDetector("apt-sources", "1.0"), - } - - realLayers = map[int]database.Layer{ - 2: { - Hash: "layer-1", - By: []database.Detector{realDetectors[1], realDetectors[2]}, - Features: []database.LayerFeature{ - {realFeatures[1], realDetectors[2], database.Namespace{}}, - {realFeatures[2], realDetectors[2], database.Namespace{}}, - }, - Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, - }, - }, - 6: { - Hash: "layer-4", - By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, - Features: []database.LayerFeature{ - {realFeatures[4], realDetectors[3], database.Namespace{}}, - {realFeatures[3], realDetectors[2], database.Namespace{}}, - }, - Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, - {realNamespaces[3], realDetectors[4]}, - }, - }, - } - - realAncestries = map[int]database.Ancestry{ - 2: { - Name: "ancestry-2", - By: []database.Detector{realDetectors[2], realDetectors[1]}, - Layers: []database.AncestryLayer{ - { - "layer-0", - []database.AncestryFeature{}, - }, - { - "layer-1", - []database.AncestryFeature{}, - }, - { - "layer-2", - []database.AncestryFeature{ - { - realNamespacedFeatures[1], - realDetectors[2], - realDetectors[1], - }, - }, - }, - { - "layer-3b", - []database.AncestryFeature{ - { - realNamespacedFeatures[3], - realDetectors[2], - realDetectors[1], - }, - }, - }, - }, - }, - } - - realVulnerability = map[int]database.Vulnerability{ - 1: { - Name: "CVE-OPENSSL-1-DEB7", - Namespace: realNamespaces[1], - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - }, - 2: { - Name: "CVE-NOPE", - Namespace: realNamespaces[1], - Description: "A vulnerability affecting nothing", - Severity: database.UnknownSeverity, - }, - } - - realNotification = map[int]database.VulnerabilityNotification{ - 1: { - NotificationHook: database.NotificationHook{ - Name: "test", - }, - Old: takeVulnerabilityPointerFromMap(realVulnerability, 2), - New: takeVulnerabilityPointerFromMap(realVulnerability, 1), - }, - } - - fakeFeatures = map[int]database.Feature{ - 1: { - Name: "ourchat", - Version: "0.6", - VersionFormat: "dpkg", - Type: "source", - }, - } - - fakeNamespaces = map[int]database.Namespace{ - 1: {"green hat", "rpm"}, - } - fakeNamespacedFeatures = map[int]database.NamespacedFeature{ - 1: { - Feature: fakeFeatures[0], - Namespace: realNamespaces[0], - }, - } - - fakeDetector = map[int]database.Detector{ - 1: { - Name: "fake", - Version: "1.0", - DType: database.FeatureDetectorType, - }, - 2: { - Name: "fake2", - Version: "2.0", - DType: database.NamespaceDetectorType, - }, - } -) - -func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability { - x := m[id] - return &x -} - -func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry { - x := m[id] - return &x -} - -func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer { - x := m[id] - return &x -} - -func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { - rows, err := tx.Query("SELECT name, version_format FROM namespace") - if err != nil { - t.FailNow() - } - defer rows.Close() - - namespaces := []database.Namespace{} - for rows.Next() { - var ns database.Namespace - err := rows.Scan(&ns.Name, &ns.VersionFormat) - if err != nil { - t.FailNow() - } - namespaces = append(namespaces, ns) - } - - return namespaces -} - -func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool { - if expected == actual { - return true - } - - if expected == nil || actual == nil { - return assert.Equal(t, expected, actual) - } - - return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) && - AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) && - AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New) -} - -func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool { - if expected == actual { - return true - } - - if expected == nil || actual == nil { - return assert.Equal(t, expected, actual) - } - - return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) && - assert.Equal(t, expected.Limit, actual.Limit) && - assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) && - assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) && - assert.Equal(t, expected.End, actual.End) && - database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected) -} - -func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page { - if token == pagination.FirstPageToken { - return Page{} - } - - p := Page{} - if err := key.UnmarshalToken(token, &p); err != nil { - panic(err) - } - - return p -} - -func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token { - token, err := key.MarshalToken(v) - if err != nil { - panic(err) - } - - return token -} - -var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';` - -func createAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) { - uri := "postgres@127.0.0.1:5432" - connectionTemplate := "postgresql://%s?sslmode=disable" - if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" { - uri = envURI - } - - db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri)) - if err != nil { - panic(err) - } - - testName = strings.ToLower(testName) - dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05")) - t.Logf("creating temporary database name = %s", dbName) - _, err = db.Exec("CREATE DATABASE " + dbName) - if err != nil { - panic(err) - } - - testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName)) - if err != nil { - panic(err) - } - - return testDB, func() { - t.Logf("cleaning up temporary database %s", dbName) - defer db.Close() - if err := testDB.Close(); err != nil { - panic(err) - } - - if _, err := db.Exec(`DROP DATABASE ` + dbName); err != nil { - panic(err) - } - - // ensure the database is cleaned up - var count int - if err := db.QueryRow(userDBCount).Scan(&count); err != nil { - panic(err) - } - } -} - -func createTestPgSQL(t *testing.T, testName string) (*pgSQL, func()) { - connection, cleanup := createAndConnectTestDB(t, testName) - err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...) - if err != nil { - require.Nil(t, err, err.Error()) - } - - return &pgSQL{connection, nil, Config{PaginationKey: pagination.Must(pagination.NewKey()).String()}}, cleanup -} - -func createTestPgSQLWithFixtures(t *testing.T, testName string) (*pgSQL, func()) { - connection, cleanup := createTestPgSQL(t, testName) - session, err := connection.Begin() - if err != nil { - panic(err) - } - - defer session.Rollback() - - loadFixtures(session.(*pgSession)) - if err = session.Commit(); err != nil { - panic(err) - } - - return connection, cleanup -} - -func createTestPgSession(t *testing.T, testName string) (*pgSession, func()) { - connection, cleanup := createTestPgSQL(t, testName) - session, err := connection.Begin() - if err != nil { - panic(err) - } - - return session.(*pgSession), func() { - session.Rollback() - cleanup() - } -} - -func createTestPgSessionWithFixtures(t *testing.T, testName string) (*pgSession, func()) { - tx, cleanup := createTestPgSession(t, testName) - defer func() { - // ensure to cleanup when loadFixtures failed - if r := recover(); r != nil { - cleanup() - } - }() - - loadFixtures(tx) - return tx, cleanup -} - -func loadFixtures(tx *pgSession) { - _, filename, _, _ := runtime.Caller(0) - fixturePath := filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql" - d, err := ioutil.ReadFile(fixturePath) - if err != nil { - panic(err) - } - - _, err = tx.Exec(string(d)) - if err != nil { - panic(err) - } -} - -func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { - return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) -} - -func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.AffectedFeature]bool{} - for _, i := range expected { - has[i] = false - } - for _, i := range actual { - if visited, ok := has[i]; !ok { - return false - } else if visited { - return false - } - has[i] = true - } - return true - } - return false -} - -func genRandomNamespaces(t *testing.T, count int) []database.Namespace { - r := make([]database.Namespace, count) - for i := 0; i < count; i++ { - r[i] = database.Namespace{ - Name: fmt.Sprint(rand.Int()), - VersionFormat: "dpkg", - } - } - return r -} diff --git a/database/pgsql/testutil/assertion.go b/database/pgsql/testutil/assertion.go new file mode 100644 index 00000000..b823c773 --- /dev/null +++ b/database/pgsql/testutil/assertion.go @@ -0,0 +1,98 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" + "github.com/stretchr/testify/assert" +) + +func AssertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New) +} + +func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) && + assert.Equal(t, expected.Limit, actual.Limit) && + assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) && + assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) && + assert.Equal(t, expected.End, actual.End) && + database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected) +} + +func AssertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && AssertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) +} + +func AssertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.AffectedFeature]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + if visited, ok := has[i]; !ok { + return false + } else if visited { + return false + } + has[i] = true + } + return true + } + return false +} + +func AssertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.NamespacedFeature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { + return false + } + } + return true + } + return false +} diff --git a/database/pgsql/testutil/data.go b/database/pgsql/testutil/data.go new file mode 100644 index 00000000..f6d9fe16 --- /dev/null +++ b/database/pgsql/testutil/data.go @@ -0,0 +1,171 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import "github.com/coreos/clair/database" + +// int keys must be the consistent with the database ID. +var ( + RealFeatures = map[int]database.Feature{ + 1: {"ourchat", "0.5", "dpkg", "source"}, + 2: {"openssl", "1.0", "dpkg", "source"}, + 3: {"openssl", "2.0", "dpkg", "source"}, + 4: {"fake", "2.0", "rpm", "source"}, + 5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"}, + } + + RealNamespaces = map[int]database.Namespace{ + 1: {"debian:7", "dpkg"}, + 2: {"debian:8", "dpkg"}, + 3: {"fake:1.0", "rpm"}, + 4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"}, + } + + RealNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: {RealFeatures[1], RealNamespaces[1]}, + 2: {RealFeatures[2], RealNamespaces[1]}, + 3: {RealFeatures[2], RealNamespaces[2]}, + 4: {RealFeatures[3], RealNamespaces[1]}, + } + + RealDetectors = map[int]database.Detector{ + 1: database.NewNamespaceDetector("os-release", "1.0"), + 2: database.NewFeatureDetector("dpkg", "1.0"), + 3: database.NewFeatureDetector("rpm", "1.0"), + 4: database.NewNamespaceDetector("apt-sources", "1.0"), + } + + RealLayers = map[int]database.Layer{ + 2: { + Hash: "layer-1", + By: []database.Detector{RealDetectors[1], RealDetectors[2]}, + Features: []database.LayerFeature{ + {RealFeatures[1], RealDetectors[2], database.Namespace{}}, + {RealFeatures[2], RealDetectors[2], database.Namespace{}}, + }, + Namespaces: []database.LayerNamespace{ + {RealNamespaces[1], RealDetectors[1]}, + }, + }, + 6: { + Hash: "layer-4", + By: []database.Detector{RealDetectors[1], RealDetectors[2], RealDetectors[3], RealDetectors[4]}, + Features: []database.LayerFeature{ + {RealFeatures[4], RealDetectors[3], database.Namespace{}}, + {RealFeatures[3], RealDetectors[2], database.Namespace{}}, + }, + Namespaces: []database.LayerNamespace{ + {RealNamespaces[1], RealDetectors[1]}, + {RealNamespaces[3], RealDetectors[4]}, + }, + }, + } + + RealAncestries = map[int]database.Ancestry{ + 2: { + Name: "ancestry-2", + By: []database.Detector{RealDetectors[2], RealDetectors[1]}, + Layers: []database.AncestryLayer{ + { + "layer-0", + []database.AncestryFeature{}, + }, + { + "layer-1", + []database.AncestryFeature{}, + }, + { + "layer-2", + []database.AncestryFeature{ + { + RealNamespacedFeatures[1], + RealDetectors[2], + RealDetectors[1], + }, + }, + }, + { + "layer-3b", + []database.AncestryFeature{ + { + RealNamespacedFeatures[3], + RealDetectors[2], + RealDetectors[1], + }, + }, + }, + }, + }, + } + + RealVulnerability = map[int]database.Vulnerability{ + 1: { + Name: "CVE-OPENSSL-1-DEB7", + Namespace: RealNamespaces[1], + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + 2: { + Name: "CVE-NOPE", + Namespace: RealNamespaces[1], + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, + } + + RealNotification = map[int]database.VulnerabilityNotification{ + 1: { + NotificationHook: database.NotificationHook{ + Name: "test", + }, + Old: takeVulnerabilityPointerFromMap(RealVulnerability, 2), + New: takeVulnerabilityPointerFromMap(RealVulnerability, 1), + }, + } + + FakeFeatures = map[int]database.Feature{ + 1: { + Name: "ourchat", + Version: "0.6", + VersionFormat: "dpkg", + Type: "source", + }, + } + + FakeNamespaces = map[int]database.Namespace{ + 1: {"green hat", "rpm"}, + } + + FakeNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: { + Feature: FakeFeatures[0], + Namespace: RealNamespaces[0], + }, + } + + FakeDetector = map[int]database.Detector{ + 1: { + Name: "fake", + Version: "1.0", + DType: database.FeatureDetectorType, + }, + 2: { + Name: "fake2", + Version: "2.0", + DType: database.NamespaceDetectorType, + }, + } +) diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testutil/data.sql similarity index 100% rename from database/pgsql/testdata/data.sql rename to database/pgsql/testutil/data.sql diff --git a/database/pgsql/testutil/testdb.go b/database/pgsql/testutil/testdb.go new file mode 100644 index 00000000..558a3dc3 --- /dev/null +++ b/database/pgsql/testutil/testdb.go @@ -0,0 +1,207 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "database/sql" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/coreos/clair/database/pgsql/migrations" + "github.com/coreos/clair/pkg/pagination" + "github.com/remind101/migrate" +) + +var TestPaginationKey = pagination.Must(pagination.NewKey()) + +var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';` + +func CreateAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) { + uri := "postgres@127.0.0.1:5432" + connectionTemplate := "postgresql://%s?sslmode=disable" + if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" { + uri = envURI + } + + db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri)) + if err != nil { + panic(err) + } + + testName = strings.ToLower(testName) + dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05")) + t.Logf("creating temporary database name = %s", dbName) + _, err = db.Exec("CREATE DATABASE " + dbName) + if err != nil { + panic(err) + } + + testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName)) + if err != nil { + panic(err) + } + + return testDB, func() { + cleanupTestDB(t, dbName, db, testDB) + } +} + +func cleanupTestDB(t *testing.T, name string, db, testDB *sql.DB) { + t.Logf("cleaning up temporary database %s", name) + if db == nil { + panic("db is none") + } + + if testDB == nil { + panic("testDB is none") + } + + defer db.Close() + + if err := testDB.Close(); err != nil { + panic(err) + } + + // Kill any opened connection. + if _, err := db.Exec(` + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = $1 + AND pid <> pg_backend_pid()`, name); err != nil { + panic(err) + } + + if _, err := db.Exec(`DROP DATABASE ` + name); err != nil { + panic(err) + } + + // ensure the database is cleaned up + var count int + if err := db.QueryRow(userDBCount).Scan(&count); err != nil { + panic(err) + } +} + +func CreateTestDB(t *testing.T, testName string) (*sql.DB, func()) { + connection, cleanup := CreateAndConnectTestDB(t, testName) + err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...) + if err != nil { + panic(err) + } + + return connection, cleanup +} + +func CreateTestDBWithFixture(t *testing.T, testName string) (*sql.DB, func()) { + connection, cleanup := CreateTestDB(t, testName) + session, err := connection.Begin() + if err != nil { + panic(err) + } + + defer session.Rollback() + + loadFixtures(session) + if err = session.Commit(); err != nil { + panic(err) + } + + return connection, cleanup +} + +func CreateTestTx(t *testing.T, testName string) (*sql.Tx, func()) { + connection, cleanup := CreateTestDB(t, testName) + session, err := connection.Begin() + if session == nil { + panic("session is none") + } + + if err != nil { + panic(err) + } + + return session, func() { + session.Rollback() + cleanup() + } +} + +func CreateTestTxWithFixtures(t *testing.T, testName string) (*sql.Tx, func()) { + tx, cleanup := CreateTestTx(t, testName) + defer func() { + // ensure to cleanup when loadFixtures failed + if r := recover(); r != nil { + cleanup() + } + }() + + loadFixtures(tx) + return tx, cleanup +} + +func loadFixtures(tx *sql.Tx) { + _, filename, _, _ := runtime.Caller(0) + fixturePath := filepath.Join(filepath.Dir(filename), "data.sql") + d, err := ioutil.ReadFile(fixturePath) + if err != nil { + panic(err) + } + + _, err = tx.Exec(string(d)) + if err != nil { + panic(err) + } +} + +func OpenSessionForTest(t *testing.T, name string, loadFixture bool) (*sql.DB, *sql.Tx) { + var db *sql.DB + if loadFixture { + db, _ = CreateTestDB(t, name) + } else { + db, _ = CreateTestDBWithFixture(t, name) + } + + tx, err := db.Begin() + if err != nil { + panic(err) + } + + return db, tx +} + +func RestartTransaction(db *sql.DB, tx *sql.Tx, commit bool) *sql.Tx { + if !commit { + if err := tx.Rollback(); err != nil { + panic(err) + } + } else { + if err := tx.Commit(); err != nil { + panic(err) + } + } + + tx, err := db.Begin() + if err != nil { + panic(err) + } + + return tx +} diff --git a/database/pgsql/testutil/testutil.go b/database/pgsql/testutil/testutil.go new file mode 100644 index 00000000..03c81dd7 --- /dev/null +++ b/database/pgsql/testutil/testutil.go @@ -0,0 +1,94 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "database/sql" + "fmt" + "math/rand" + "testing" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/pkg/pagination" +) + +func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability { + x := m[id] + return &x +} + +func TakeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry { + x := m[id] + return &x +} + +func TakeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer { + x := m[id] + return &x +} + +func ListNamespaces(t *testing.T, tx *sql.Tx) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") + if err != nil { + t.FailNow() + } + defer rows.Close() + + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() + } + namespaces = append(namespaces, ns) + } + + return namespaces +} + +func mustUnmarshalToken(key pagination.Key, token pagination.Token) page.Page { + if token == pagination.FirstPageToken { + return page.Page{} + } + + p := page.Page{} + if err := key.UnmarshalToken(token, &p); err != nil { + panic(err) + } + + return p +} + +func MustMarshalToken(key pagination.Key, v interface{}) pagination.Token { + token, err := key.MarshalToken(v) + if err != nil { + panic(err) + } + + return token +} + +func GenRandomNamespaces(t *testing.T, count int) []database.Namespace { + r := make([]database.Namespace, count) + for i := 0; i < count; i++ { + r[i] = database.Namespace{ + Name: fmt.Sprint(rand.Int()), + VersionFormat: "dpkg", + } + } + return r +} diff --git a/database/pgsql/util/error.go b/database/pgsql/util/error.go new file mode 100644 index 00000000..de724c93 --- /dev/null +++ b/database/pgsql/util/error.go @@ -0,0 +1,63 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "database/sql" + "strings" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/monitoring" + "github.com/coreos/clair/pkg/commonerr" + "github.com/lib/pq" + "github.com/sirupsen/logrus" +) + +// IsErrUniqueViolation determines is the given error is a unique contraint violation. +func IsErrUniqueViolation(err error) bool { + pqErr, ok := err.(*pq.Error) + return ok && pqErr.Code == "23505" +} + +// HandleError logs an error with an extra description and masks the error if it's an SQL one. +// The function ensures we never return plain SQL errors and leak anything. +// The function should be used for every database query error. +func HandleError(desc string, err error) error { + if err == nil { + return nil + } + + if err == sql.ErrNoRows { + return commonerr.ErrNotFound + } + + if pqErr, ok := err.(*pq.Error); ok { + if pqErr.Fatal() { + panic(pqErr) + } + + if pqErr.Code == "42601" { + panic("invalid query: " + desc + ", info: " + err.Error()) + } + } + + logrus.WithError(err).WithField("Description", desc).Error("database: handled database error") + monitoring.PromErrorsTotal.WithLabelValues(desc).Inc() + if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { + return database.ErrBackendException + } + + return err +} diff --git a/database/pgsql/util/queries.go b/database/pgsql/util/queries.go new file mode 100644 index 00000000..271845b0 --- /dev/null +++ b/database/pgsql/util/queries.go @@ -0,0 +1,57 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "strings" + + "github.com/lib/pq" +) + +func QueryInsert(count int, table string, columns ...string) string { + base := `INSERT INTO %s (%s) VALUES %s` + t := pq.QuoteIdentifier(table) + cols := make([]string, len(columns)) + for i, c := range columns { + cols[i] = pq.QuoteIdentifier(c) + } + colsQuoted := strings.Join(cols, ",") + return fmt.Sprintf(base, t, colsQuoted, QueryString(len(columns), count)) +} + +func QueryPersist(count int, table, constraint string, columns ...string) string { + ct := "" + if constraint != "" { + ct = fmt.Sprintf("ON CONSTRAINT %s", constraint) + } + return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", QueryInsert(count, table, columns...), ct) +} + +// size of key and array should be both greater than 0 +func QueryString(keySize, arraySize int) string { + if arraySize <= 0 || keySize <= 0 { + panic("Bulk Query requires size of element tuple and number of elements to be greater than 0") + } + keys := make([]string, 0, arraySize) + for i := 0; i < arraySize; i++ { + key := make([]string, keySize) + for j := 0; j < keySize; j++ { + key[j] = fmt.Sprintf("$%d", i*keySize+j+1) + } + keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ","))) + } + return strings.Join(keys, ",") +} diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability/vulnerability.go similarity index 55% rename from database/pgsql/vulnerability.go rename to database/pgsql/vulnerability/vulnerability.go index e96d6d47..4245299a 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability/vulnerability.go @@ -12,23 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package vulnerability import ( "database/sql" "errors" + "fmt" "time" "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/monitoring" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/ext/versionfmt" + "github.com/coreos/clair/pkg/pagination" ) const ( - lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` - searchVulnerability = ` SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format FROM vulnerability AS v, namespace AS n @@ -38,45 +42,12 @@ const ( AND v.deleted_at IS NULL ` - insertVulnerabilityAffected = ` - INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin) - VALUES ($1, $2, $3, $4, $5) - RETURNING ID - ` - - searchVulnerabilityAffected = ` - SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin - FROM vulnerability_affected_feature AS vaf, feature_type AS t - WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1) - ` - searchVulnerabilityByID = ` SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND v.id = $1` - searchVulnerabilityPotentialAffected = ` - WITH req AS ( - SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id - FROM vulnerability_affected_feature AS vaf, - vulnerability AS v, - namespace AS n - WHERE vaf.vulnerability_id = ANY($1) - AND v.id = vaf.vulnerability_id - AND n.id = v.namespace_id - ) - SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by - FROM feature AS f, namespaced_feature AS nf, req - WHERE f.name = req.name - AND f.type = req.type - AND nf.namespace_id = req.n_id - AND nf.feature_id = f.id` - - insertVulnerabilityAffectedNamespacedFeature = ` - INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) - VALUES ($1, $2, $3)` - insertVulnerability = ` WITH ns AS ( SELECT id FROM namespace WHERE name = $6 AND version_format = $7 @@ -92,12 +63,56 @@ const ( AND name = $2 AND deleted_at IS NULL RETURNING id` -) -var ( - errVulnerabilityNotFound = errors.New("vulnerability is not in database") + searchNotificationVulnerableAncestry = ` + SELECT DISTINCT ON (a.id) + a.id, a.name + FROM vulnerability_affected_namespaced_feature AS vanf, + ancestry_layer AS al, ancestry_feature AS af, ancestry AS a + WHERE vanf.vulnerability_id = $1 + AND a.id >= $2 + AND al.ancestry_id = a.id + AND al.id = af.ancestry_layer_id + AND af.namespaced_feature_id = vanf.namespaced_feature_id + ORDER BY a.id ASC + LIMIT $3;` ) +func queryInvalidateVulnerabilityCache(count int) string { + return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature + WHERE vulnerability_id IN (%s)`, + util.QueryString(1, count)) +} + +// NOTE(Sida): Every search query can only have count less than postgres set +// stack depth. IN will be resolved to nested OR_s and the parser might exceed +// stack depth. TODO(Sida): Generate different queries for different count: if +// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for +// count > 65535, use is expected to split data into batches. +func querySearchLastDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT vid, vname, nname FROM ( + SELECT v.id AS vid, v.name AS vname, n.name AS nname, + row_number() OVER ( + PARTITION by (v.name, n.name) + ORDER BY v.deleted_at DESC + ) AS rownum + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND (v.name, n.name) IN ( %s ) + AND v.deleted_at IS NOT NULL + ) tmp WHERE rownum <= 1`, + util.QueryString(2, count)) +} + +func querySearchNotDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) + AND v.deleted_at IS NULL`, + util.QueryString(2, count)) +} + type affectedAncestry struct { name string id int64 @@ -113,8 +128,8 @@ type affectedFeatureRows struct { rows map[int64]database.AffectedFeature } -func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { - defer observeQueryTime("findVulnerabilities", "", time.Now()) +func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now()) resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) vulnIDMap := map[int64][]*database.NullableVulnerability{} @@ -151,7 +166,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit if err != nil && err != sql.ErrNoRows { stmt.Close() - return nil, handleError("searchVulnerability", err) + return nil, util.HandleError("searchVulnerability", err) } vuln.Valid = id.Valid resultVuln[i] = vuln @@ -161,7 +176,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit } if err := stmt.Close(); err != nil { - return nil, handleError("searchVulnerability", err) + return nil, util.HandleError("searchVulnerability", err) } toQuery := make([]int64, 0, len(vulnIDMap)) @@ -172,7 +187,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit // load vulnerability affected features rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { - return nil, handleError("searchVulnerabilityAffected", err) + return nil, util.HandleError("searchVulnerabilityAffected", err) } for rows.Next() { @@ -183,7 +198,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion) if err != nil { - return nil, handleError("searchVulnerabilityAffected", err) + return nil, util.HandleError("searchVulnerabilityAffected", err) } for _, vuln := range vulnIDMap[id] { @@ -195,41 +210,40 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit return resultVuln, nil } -func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { - defer observeQueryTime("insertVulnerabilities", "all", time.Now()) +func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error { + defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now()) // bulk insert vulnerabilities - vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) + vulnIDs, err := insertVulnerabilities(tx, vulnerabilities) if err != nil { return err } // bulk insert vulnerability affected features - vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) + vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities) if err != nil { return err } - return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) + return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap) } // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. // // i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. -func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { +func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { var ( vulnFeature = map[int64]affectedFeatureRows{} affectedID int64 ) - types, err := tx.getFeatureTypeMap() + types, err := feature.GetFeatureTypeMap(tx) if err != nil { return nil, err } - //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { - return nil, handleError("insertVulnerabilityAffected", err) + return nil, util.HandleError("insertVulnerabilityAffected", err) } defer stmt.Close() @@ -237,9 +251,9 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne // affected feature row ID -> affected feature affectedFeatures := map[int64]database.AffectedFeature{} for _, f := range vuln.Affected { - err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) + err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) if err != nil { - return nil, handleError("insertVulnerabilityAffected", err) + return nil, util.HandleError("insertVulnerabilityAffected", err) } affectedFeatures[affectedID] = f } @@ -251,7 +265,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne // insertVulnerabilities inserts a set of unique vulnerabilities into database, // under the assumption that all vulnerabilities are valid. -func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { +func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { var ( vulnID int64 vulnIDs = make([]int64, 0, len(vulnerabilities)) @@ -274,7 +288,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerability) if err != nil { - return nil, handleError("insertVulnerability", err) + return nil, util.HandleError("insertVulnerability", err) } defer stmt.Close() @@ -283,7 +297,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil vuln.Link, &vuln.Severity, &vuln.Metadata, vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { - return nil, handleError("insertVulnerability", err) + return nil, util.HandleError("insertVulnerability", err) } vulnIDs = append(vulnIDs, vulnID) @@ -292,19 +306,19 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil return vulnIDs, nil } -func (tx *pgSession) lockFeatureVulnerabilityCache() error { +func LockFeatureVulnerabilityCache(tx *sql.Tx) error { _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { - return handleError("lockVulnerabilityAffects", err) + return util.HandleError("lockVulnerabilityAffects", err) } return nil } // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID // to affected feature rows and caches them. -func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { +func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error { // Prevent InsertNamespacedFeatures to modify it. - err := tx.lockFeatureVulnerabilityCache() + err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } @@ -316,7 +330,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { - return handleError("searchVulnerabilityPotentialAffected", err) + return util.HandleError("searchVulnerabilityPotentialAffected", err) } defer rows.Close() @@ -332,7 +346,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) if err != nil { - return handleError("searchVulnerabilityPotentialAffected", err) + return util.HandleError("searchVulnerabilityPotentialAffected", err) } candidate, ok := affected[vulnID].rows[addedBy] @@ -361,7 +375,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int for _, r := range relation { result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) if err != nil { - return handleError("insertVulnerabilityAffectedNamespacedFeature", err) + return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err) } if num, err := result.RowsAffected(); err == nil { @@ -377,27 +391,27 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int return nil } -func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) +func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error { + defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now()) - vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) + vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities) if err != nil { return err } - if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { + if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil { return err } return nil } -func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { +func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error { if len(vulnerabilityIDs) == 0 { return nil } // Prevent InsertNamespacedFeatures to modify it. - err := tx.lockFeatureVulnerabilityCache() + err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } @@ -410,13 +424,13 @@ func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) erro _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) if err != nil { - return handleError("removeVulnerabilityAffectedFeature", err) + return util.HandleError("removeVulnerabilityAffectedFeature", err) } return nil } -func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { +func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) { var ( vulnID sql.NullInt64 vulnIDs []int64 @@ -425,17 +439,17 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul // mark vulnerabilities deleted stmt, err := tx.Prepare(removeVulnerability) if err != nil { - return nil, handleError("removeVulnerability", err) + return nil, util.HandleError("removeVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) if err != nil { - return nil, handleError("removeVulnerability", err) + return nil, util.HandleError("removeVulnerability", err) } if !vulnID.Valid { - return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) + return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) } vulnIDs = append(vulnIDs, vulnID.Int64) } @@ -444,15 +458,15 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul // findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in // database and the order of output array is not guaranteed. -func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { - return tx.findVulnerabilityIDs(vulnIDs, true) +func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return FindVulnerabilityIDs(tx, vulnIDs, true) } -func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { - return tx.findVulnerabilityIDs(vulnIDs, false) +func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return FindVulnerabilityIDs(tx, vulnIDs, false) } -func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { +func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { if len(vulnIDs) == 0 { return nil, nil } @@ -474,7 +488,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi rows, err := tx.Query(query, keys...) if err != nil { - return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) + return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err) } defer rows.Close() @@ -485,7 +499,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi for rows.Next() { err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) if err != nil { - return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) + return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) } vulnIDMap[vulnID] = id } @@ -497,3 +511,67 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi return ids, nil } + +func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) { + vulnPage := database.PagedVulnerableAncestries{Limit: limit} + currentPage := page.Page{0} + if currentToken != pagination.FirstPageToken { + if err := key.UnmarshalToken(currentToken, ¤tPage); err != nil { + return vulnPage, err + } + } + + if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( + &vulnPage.Name, + &vulnPage.Description, + &vulnPage.Link, + &vulnPage.Severity, + &vulnPage.Metadata, + &vulnPage.Namespace.Name, + &vulnPage.Namespace.VersionFormat, + ); err != nil { + return vulnPage, util.HandleError("searchVulnerabilityByID", err) + } + + // the last result is used for the next page's startID + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) + if err != nil { + return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) + } + defer rows.Close() + + ancestries := []affectedAncestry{} + for rows.Next() { + var ancestry affectedAncestry + err := rows.Scan(&ancestry.id, &ancestry.name) + if err != nil { + return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) + } + ancestries = append(ancestries, ancestry) + } + + lastIndex := 0 + if len(ancestries)-1 < limit { + lastIndex = len(ancestries) + vulnPage.End = true + } else { + // Use the last ancestry's ID as the next page. + lastIndex = len(ancestries) - 1 + vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id}) + if err != nil { + return vulnPage, err + } + } + + vulnPage.Affected = map[int]string{} + for _, ancestry := range ancestries[0:lastIndex] { + vulnPage.Affected[int(ancestry.id)] = ancestry.name + } + + vulnPage.Current, err = key.MarshalToken(currentPage) + if err != nil { + return vulnPage, err + } + + return vulnPage, nil +} diff --git a/database/pgsql/vulnerability/vulnerability_affected_feature.go b/database/pgsql/vulnerability/vulnerability_affected_feature.go new file mode 100644 index 00000000..97716dd6 --- /dev/null +++ b/database/pgsql/vulnerability/vulnerability_affected_feature.go @@ -0,0 +1,118 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vulnerability + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/ext/versionfmt" + "github.com/lib/pq" +) + +const ( + searchPotentialAffectingVulneraibilities = ` + SELECT nf.id, v.id, vaf.affected_version, vaf.id + FROM vulnerability_affected_feature AS vaf, vulnerability AS v, + namespaced_feature AS nf, feature AS f + WHERE nf.id = ANY($1) + AND nf.feature_id = f.id + AND nf.namespace_id = v.namespace_id + AND vaf.feature_name = f.name + AND vaf.feature_type = f.type + AND vaf.vulnerability_id = v.id + AND v.deleted_at IS NULL` + insertVulnerabilityAffected = ` + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin) + VALUES ($1, $2, $3, $4, $5) + RETURNING ID + ` + searchVulnerabilityAffected = ` + SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin + FROM vulnerability_affected_feature AS vaf, feature_type AS t + WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1) + ` + + searchVulnerabilityPotentialAffected = ` + WITH req AS ( + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id + FROM vulnerability_affected_feature AS vaf, + vulnerability AS v, + namespace AS n + WHERE vaf.vulnerability_id = ANY($1) + AND v.id = vaf.vulnerability_id + AND n.id = v.namespace_id + ) + SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by + FROM feature AS f, namespaced_feature AS nf, req + WHERE f.name = req.name + AND f.type = req.type + AND nf.namespace_id = req.n_id + AND nf.feature_id = f.id` +) + +type vulnerabilityCache struct { + nsFeatureID int64 + vulnID int64 + vulnAffectingID int64 +} + +func SearchAffectingVulnerabilities(tx *sql.Tx, features []database.NamespacedFeature) ([]vulnerabilityCache, error) { + if len(features) == 0 { + return nil, nil + } + + ids, err := feature.FindNamespacedFeatureIDs(tx, features) + if err != nil { + return nil, err + } + + fMap := map[int64]database.NamespacedFeature{} + for i, f := range features { + if !ids[i].Valid { + return nil, database.ErrMissingEntities + } + fMap[ids[i].Int64] = f + } + + cacheTable := []vulnerabilityCache{} + rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids)) + if err != nil { + return nil, util.HandleError("searchPotentialAffectingVulneraibilities", err) + } + + defer rows.Close() + for rows.Next() { + var ( + cache vulnerabilityCache + affected string + ) + + err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID) + if err != nil { + return nil, err + } + + if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil { + return nil, err + } else if ok { + cacheTable = append(cacheTable, cache) + } + } + + return cacheTable, nil +} diff --git a/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go b/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go new file mode 100644 index 00000000..2a09fe5b --- /dev/null +++ b/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go @@ -0,0 +1,142 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vulnerability + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" + "github.com/lib/pq" + log "github.com/sirupsen/logrus" +) + +const ( + searchNamespacedFeaturesVulnerabilities = ` + SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, + v.severity, v.metadata, vaf.fixedin, n.name, n.version_format + FROM vulnerability_affected_namespaced_feature AS vanf, + Vulnerability AS v, + vulnerability_affected_feature AS vaf, + namespace AS n + WHERE vanf.namespaced_feature_id = ANY($1) + AND vaf.id = vanf.added_by + AND v.id = vanf.vulnerability_id + AND n.id = v.namespace_id + AND v.deleted_at IS NULL` + + lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` + + insertVulnerabilityAffectedNamespacedFeature = ` + INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) + VALUES ($1, $2, $3)` +) + +func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string { + return util.QueryPersist(count, "vulnerability_affected_namespaced_feature", + "vulnerability_affected_namesp_vulnerability_id_namespaced_f_key", + "vulnerability_id", + "namespaced_feature_id", + "added_by") +} + +// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the +// feature. +func FindAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + if len(features) == 0 { + return nil, nil + } + + vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) + featureIDs, err := feature.FindNamespacedFeatureIDs(tx, features) + if err != nil { + return nil, err + } + + for i, id := range featureIDs { + if id.Valid { + vulnerableFeatures[i].Valid = true + vulnerableFeatures[i].NamespacedFeature = features[i] + } + } + + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs)) + if err != nil { + return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err) + } + defer rows.Close() + + for rows.Next() { + var ( + featureID int64 + vuln database.VulnerabilityWithFixedIn + ) + + err := rows.Scan(&featureID, + &vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.FixedInVersion, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat, + ) + + if err != nil { + return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err) + } + + for i, id := range featureIDs { + if id.Valid && id.Int64 == featureID { + vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln) + } + } + } + + return vulnerableFeatures, nil +} + +func CacheAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + _, err := tx.Exec(lockVulnerabilityAffects) + if err != nil { + return util.HandleError("lockVulnerabilityAffects", err) + } + + cache, err := SearchAffectingVulnerabilities(tx, features) + + keys := make([]interface{}, 0, len(cache)*3) + for _, c := range cache { + keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID) + } + + if len(cache) == 0 { + return nil + } + + affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...) + if err != nil { + return util.HandleError("persistVulnerabilityAffectedNamespacedFeature", err) + } + if count, err := affected.RowsAffected(); err != nil { + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count) + } + return nil +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability/vulnerability_test.go similarity index 50% rename from database/pgsql/vulnerability_test.go rename to database/pgsql/vulnerability/vulnerability_test.go index 759bfe2f..d911adc6 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability/vulnerability_test.go @@ -12,19 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package vulnerability import ( + "database/sql" + "math/rand" + "strconv" "testing" + "github.com/pborman/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/testutil" + "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/coreos/clair/pkg/strutil" ) func TestInsertVulnerabilities(t *testing.T) { - store, tx := openSessionForTest(t, "InsertVulnerabilities", true) + store, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilities") + defer cleanup() ns1 := database.Namespace{ Name: "name", @@ -56,45 +67,48 @@ func TestInsertVulnerabilities(t *testing.T) { Vulnerability: v2, } + tx, err := store.Begin() + require.Nil(t, err) + // empty - err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{}) assert.Nil(t, err) // invalid content: vwa1 is invalid - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa1, vwa2}) assert.NotNil(t, err) - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // invalid content: duplicated input - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2, vwa2}) assert.NotNil(t, err) - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // valid content - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2}) assert.Nil(t, err) - tx = restartSession(t, store, tx, true) + tx = testutil.RestartTransaction(store, tx, true) // ensure the content is in database - vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) + vulns, err := FindVulnerabilities(tx, []database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) if assert.Nil(t, err) && assert.Len(t, vulns, 1) { assert.True(t, vulns[0].Valid) } - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // valid content: vwa2 removed and inserted - err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) + err = DeleteVulnerabilities(tx, []database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) assert.Nil(t, err) - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2}) assert.Nil(t, err) - closeTest(t, store, tx) + require.Nil(t, tx.Rollback()) } func TestCachingVulnerable(t *testing.T) { - datastore, tx := openSessionForTest(t, "CachingVulnerable", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "CachingVulnerable") + defer cleanup() ns := database.Namespace{ Name: "debian:8", @@ -163,11 +177,8 @@ func TestCachingVulnerable(t *testing.T) { FixedInVersion: "2.2", } - if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { - t.FailNow() - } - - r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f}) + require.Nil(t, InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vuln, vuln2})) + r, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{f}) assert.Nil(t, err) assert.Len(t, r, 1) for _, anf := range r { @@ -186,10 +197,10 @@ func TestCachingVulnerable(t *testing.T) { } func TestFindVulnerabilities(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindVulnerabilities", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilities") + defer cleanup() - vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + vuln, err := FindVulnerabilities(tx, []database.VulnerabilityID{ {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-NOPE", Namespace: "debian:7"}, {Name: "CVE-NOT HERE"}, @@ -255,7 +266,7 @@ func TestFindVulnerabilities(t *testing.T) { expected, ok := expectedExistingMap[key] if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { - assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) + testutil.AssertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) } } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { t.FailNow() @@ -264,7 +275,7 @@ func TestFindVulnerabilities(t *testing.T) { } // same vulnerability - r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + r, err := FindVulnerabilities(tx, []database.VulnerabilityID{ {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, }) @@ -273,22 +284,22 @@ func TestFindVulnerabilities(t *testing.T) { for _, vuln := range r { if assert.True(t, vuln.Valid) { expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] - assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) + testutil.AssertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) } } } } func TestDeleteVulnerabilities(t *testing.T) { - datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteVulnerabilities") + defer cleanup() remove := []database.VulnerabilityID{} // empty case - assert.Nil(t, tx.DeleteVulnerabilities(remove)) + assert.Nil(t, DeleteVulnerabilities(tx, remove)) // invalid case remove = append(remove, database.VulnerabilityID{}) - assert.NotNil(t, tx.DeleteVulnerabilities(remove)) + assert.NotNil(t, DeleteVulnerabilities(tx, remove)) // valid case validRemove := []database.VulnerabilityID{ @@ -296,8 +307,8 @@ func TestDeleteVulnerabilities(t *testing.T) { {Name: "CVE-NOPE", Namespace: "debian:7"}, } - assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) - vuln, err := tx.FindVulnerabilities(validRemove) + assert.Nil(t, DeleteVulnerabilities(tx, validRemove)) + vuln, err := FindVulnerabilities(tx, validRemove) if assert.Nil(t, err) { for _, v := range vuln { assert.False(t, v.Valid) @@ -306,20 +317,158 @@ func TestDeleteVulnerabilities(t *testing.T) { } func TestFindVulnerabilityIDs(t *testing.T) { - store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true) - defer closeTest(t, store, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilityIDs") + defer cleanup() - ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) + ids, err := FindLatestDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) if assert.Nil(t, err) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) { assert.Fail(t, "") } } - ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) + ids, err = FindNotDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) if assert.Nil(t, err) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) { assert.Fail(t, "") } } } + +func TestFindAffectedNamespacedFeatures(t *testing.T) { + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindAffectedNamespacedFeatures") + defer cleanup() + + ns := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + Type: database.SourcePackage, + }, + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + } + + ans, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{ns}) + if assert.Nil(t, err) && + assert.Len(t, ans, 1) && + assert.True(t, ans[0].Valid) && + assert.Len(t, ans[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) + } +} + +func genRandomVulnerabilityAndNamespacedFeature(t *testing.T, store *sql.DB) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { + tx, err := store.Begin() + if err != nil { + panic(err) + } + + numFeatures := 100 + numVulnerabilities := 100 + + featureName := "TestFeature" + featureVersionFormat := dpkg.ParserName + // Insert the namespace on which we'll work. + ns := database.Namespace{ + Name: "TestRaceAffectsFeatureNamespace1", + VersionFormat: dpkg.ParserName, + } + + if !assert.Nil(t, namespace.PersistNamespaces(tx, []database.Namespace{ns})) { + t.FailNow() + } + + // Generate Distinct random features + features := make([]database.Feature, numFeatures) + nsFeatures := make([]database.NamespacedFeature, numFeatures) + for i := 0; i < numFeatures; i++ { + version := rand.Intn(numFeatures) + + features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat) + nsFeatures[i] = database.NamespacedFeature{ + Namespace: ns, + Feature: features[i], + } + } + + if !assert.Nil(t, feature.PersistFeatures(tx, features)) { + t.FailNow() + } + + // Generate vulnerabilities. + vulnerabilities := []database.VulnerabilityWithAffected{} + for i := 0; i < numVulnerabilities; i++ { + // any version less than this is vulnerable + version := rand.Intn(numFeatures) + 1 + + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: uuid.New(), + Namespace: ns, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: featureName, + FeatureType: database.SourcePackage, + AffectedVersion: strconv.Itoa(version), + FixedInVersion: strconv.Itoa(version), + }, + }, + } + + vulnerabilities = append(vulnerabilities, vulnerability) + } + tx.Commit() + + return nsFeatures, vulnerabilities +} + +func TestVulnChangeAffectsVulnerableFeatures(t *testing.T) { + db, cleanup := testutil.CreateTestDB(t, "caching") + defer cleanup() + + nsFeatures, vulnerabilities := genRandomVulnerabilityAndNamespacedFeature(t, db) + tx, err := db.Begin() + require.Nil(t, err) + + require.Nil(t, feature.PersistNamespacedFeatures(tx, nsFeatures)) + require.Nil(t, tx.Commit()) + + tx, err = db.Begin() + require.Nil(t, InsertVulnerabilities(tx, vulnerabilities)) + require.Nil(t, tx.Commit()) + + tx, err = db.Begin() + require.Nil(t, err) + defer tx.Rollback() + + affected, err := FindAffectedNamespacedFeatures(tx, nsFeatures) + require.Nil(t, err) + + for _, ansf := range affected { + require.True(t, ansf.Valid) + + expectedAffectedNames := []string{} + for _, vuln := range vulnerabilities { + if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { + if ok { + expectedAffectedNames = append(expectedAffectedNames, vuln.Name) + } + } + } + + actualAffectedNames := []string{} + for _, s := range ansf.AffectedBy { + actualAffectedNames = append(actualAffectedNames, s.Name) + } + + require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames) + require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) + } +} diff --git a/database/vulnerability.go b/database/vulnerability.go new file mode 100644 index 00000000..4eb40c42 --- /dev/null +++ b/database/vulnerability.go @@ -0,0 +1,50 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +// VulnerabilityID is an identifier for every vulnerability. Every vulnerability +// has unique namespace and name. +type VulnerabilityID struct { + Name string + Namespace string +} + +// Vulnerability represents CVE or similar vulnerability reports. +type Vulnerability struct { + Name string + Namespace Namespace + + Description string + Link string + Severity Severity + + Metadata MetadataMap +} + +// VulnerabilityWithAffected is a vulnerability with all known affected +// features. +type VulnerabilityWithAffected struct { + Vulnerability + + Affected []AffectedFeature +} + +// NullableVulnerability is a vulnerability with whether the vulnerability is +// found in datastore. +type NullableVulnerability struct { + VulnerabilityWithAffected + + Valid bool +}