Merge pull request #721 from KeyboardNerd/cache

Restructure database folder
This commit is contained in:
Sida Chen 2019-03-13 16:56:28 -04:00 committed by GitHub
commit 2c7838eac7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 3825 additions and 3430 deletions

96
database/ancestry.go Normal file
View File

@ -0,0 +1,96 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
// Ancestry is a manifest that keeps all layers in an image in order.
type Ancestry struct {
// Name is a globally unique value for a set of layers. This is often the
// sha256 digest of an OCI/Docker manifest.
Name string `json:"name"`
// By contains the processors that are used when computing the
// content of this ancestry.
By []Detector `json:"by"`
// Layers should be ordered and i_th layer is the parent of i+1_th layer in
// the slice.
Layers []AncestryLayer `json:"layers"`
}
// Valid checks if the ancestry is compliant to spec.
func (a *Ancestry) Valid() bool {
if a == nil {
return false
}
if a.Name == "" {
return false
}
for _, d := range a.By {
if !d.Valid() {
return false
}
}
for _, l := range a.Layers {
if !l.Valid() {
return false
}
}
return true
}
// AncestryLayer is a layer with all detected namespaced features.
type AncestryLayer struct {
// Hash is the sha-256 tarsum on the layer's blob content.
Hash string `json:"hash"`
// Features are the features introduced by this layer when it was
// processed.
Features []AncestryFeature `json:"features"`
}
// Valid checks if the Ancestry Layer is compliant to the spec.
func (l *AncestryLayer) Valid() bool {
if l == nil {
return false
}
if l.Hash == "" {
return false
}
return true
}
// GetFeatures returns the Ancestry's features.
func (l *AncestryLayer) GetFeatures() []NamespacedFeature {
nsf := make([]NamespacedFeature, 0, len(l.Features))
for _, f := range l.Features {
nsf = append(nsf, f.NamespacedFeature)
}
return nsf
}
// AncestryFeature is a namespaced feature with the detectors used to
// find this feature.
type AncestryFeature struct {
NamespacedFeature `json:"namespacedFeature"`
// FeatureBy is the detector that detected the feature.
FeatureBy Detector `json:"featureBy"`
// NamespaceBy is the detector that detected the namespace.
NamespaceBy Detector `json:"namespaceBy"`
}

96
database/feature.go Normal file
View File

@ -0,0 +1,96 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
// Feature represents a package detected in a layer but the namespace is not
// determined.
//
// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary
// dpkg is the version format of the installer package manager, which in this
// case could be dpkg or apk.
type Feature struct {
Name string `json:"name"`
Version string `json:"version"`
VersionFormat string `json:"versionFormat"`
Type FeatureType `json:"type"`
}
// NamespacedFeature is a feature with determined namespace and can be affected
// by vulnerabilities.
//
// e.g. OpenSSL 1.0 dpkg Debian:7.
type NamespacedFeature struct {
Feature `json:"feature"`
Namespace Namespace `json:"namespace"`
}
// AffectedNamespacedFeature is a namespaced feature affected by the
// vulnerabilities with fixed-in versions for this feature.
type AffectedNamespacedFeature struct {
NamespacedFeature
AffectedBy []VulnerabilityWithFixedIn
}
// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve
// the affecting vulnerabilities and the fixed-in versions for the feature.
type VulnerabilityWithFixedIn struct {
Vulnerability
FixedInVersion string
}
// AffectedFeature is used to determine whether a namespaced feature is affected
// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is
// bound to vulnerability.
type AffectedFeature struct {
// FeatureType determines which type of package it affects.
FeatureType FeatureType
Namespace Namespace
FeatureName string
// FixedInVersion is known next feature version that's not affected by the
// vulnerability. Empty FixedInVersion means the unaffected version is
// unknown.
FixedInVersion string
// AffectedVersion contains the version range to determine whether or not a
// feature is affected.
AffectedVersion string
}
// NullableAffectedNamespacedFeature is an affectednamespacedfeature with
// whether it's found in datastore.
type NullableAffectedNamespacedFeature struct {
AffectedNamespacedFeature
Valid bool
}
func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature {
return &Feature{name, version, versionFormat, featureType}
}
func NewBinaryPackage(name string, version string, versionFormat string) *Feature {
return &Feature{name, version, versionFormat, BinaryPackage}
}
func NewSourcePackage(name string, version string, versionFormat string) *Feature {
return &Feature{name, version, versionFormat, SourcePackage}
}
func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature {
// TODO: namespaced feature should use pointer values
return &NamespacedFeature{*feature, *namespace}
}

65
database/layer.go Normal file
View File

@ -0,0 +1,65 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
// Layer is a layer with all the detected features and namespaces.
type Layer struct {
// Hash is the sha-256 tarsum on the layer's blob content.
Hash string `json:"hash"`
// By contains a list of detectors scanned this Layer.
By []Detector `json:"by"`
Namespaces []LayerNamespace `json:"namespaces"`
Features []LayerFeature `json:"features"`
}
func (l *Layer) GetFeatures() []Feature {
features := make([]Feature, 0, len(l.Features))
for _, f := range l.Features {
features = append(features, f.Feature)
}
return features
}
func (l *Layer) GetNamespaces() []Namespace {
namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features))
for _, ns := range l.Namespaces {
namespaces = append(namespaces, ns.Namespace)
}
for _, f := range l.Features {
if f.PotentialNamespace.Valid() {
namespaces = append(namespaces, f.PotentialNamespace)
}
}
return namespaces
}
// LayerNamespace is a namespace with detection information.
type LayerNamespace struct {
Namespace `json:"namespace"`
// By is the detector found the namespace.
By Detector `json:"by"`
}
// LayerFeature is a feature with detection information.
type LayerFeature struct {
Feature `json:"feature"`
// By is the detector found the feature.
By Detector `json:"by"`
PotentialNamespace Namespace `json:"potentialNamespace"`
}

41
database/metadata.go Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"database/sql/driver"
"encoding/json"
)
// MetadataMap is for storing the metadata returned by vulnerability database.
type MetadataMap map[string]interface{}
func (mm *MetadataMap) Scan(value interface{}) error {
if value == nil {
return nil
}
// github.com/lib/pq decodes TEXT/VARCHAR fields into strings.
val, ok := value.(string)
if !ok {
panic("got type other than []byte from database")
}
return json.Unmarshal([]byte(val), mm)
}
func (mm *MetadataMap) Value() (driver.Value, error) {
json, err := json.Marshal(*mm)
return string(json), err
}

View File

@ -1,363 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"database/sql/driver"
"encoding/json"
"fmt"
"time"
"github.com/coreos/clair/pkg/pagination"
)
// Ancestry is a manifest that keeps all layers in an image in order.
type Ancestry struct {
// Name is a globally unique value for a set of layers. This is often the
// sha256 digest of an OCI/Docker manifest.
Name string `json:"name"`
// By contains the processors that are used when computing the
// content of this ancestry.
By []Detector `json:"by"`
// Layers should be ordered and i_th layer is the parent of i+1_th layer in
// the slice.
Layers []AncestryLayer `json:"layers"`
}
// Valid checks if the ancestry is compliant to spec.
func (a *Ancestry) Valid() bool {
if a == nil {
return false
}
if a.Name == "" {
return false
}
for _, d := range a.By {
if !d.Valid() {
return false
}
}
for _, l := range a.Layers {
if !l.Valid() {
return false
}
}
return true
}
// AncestryLayer is a layer with all detected namespaced features.
type AncestryLayer struct {
// Hash is the sha-256 tarsum on the layer's blob content.
Hash string `json:"hash"`
// Features are the features introduced by this layer when it was
// processed.
Features []AncestryFeature `json:"features"`
}
// Valid checks if the Ancestry Layer is compliant to the spec.
func (l *AncestryLayer) Valid() bool {
if l == nil {
return false
}
if l.Hash == "" {
return false
}
return true
}
// GetFeatures returns the Ancestry's features.
func (l *AncestryLayer) GetFeatures() []NamespacedFeature {
nsf := make([]NamespacedFeature, 0, len(l.Features))
for _, f := range l.Features {
nsf = append(nsf, f.NamespacedFeature)
}
return nsf
}
// AncestryFeature is a namespaced feature with the detectors used to
// find this feature.
type AncestryFeature struct {
NamespacedFeature `json:"namespacedFeature"`
// FeatureBy is the detector that detected the feature.
FeatureBy Detector `json:"featureBy"`
// NamespaceBy is the detector that detected the namespace.
NamespaceBy Detector `json:"namespaceBy"`
}
// Layer is a layer with all the detected features and namespaces.
type Layer struct {
// Hash is the sha-256 tarsum on the layer's blob content.
Hash string `json:"hash"`
// By contains a list of detectors scanned this Layer.
By []Detector `json:"by"`
Namespaces []LayerNamespace `json:"namespaces"`
Features []LayerFeature `json:"features"`
}
func (l *Layer) GetFeatures() []Feature {
features := make([]Feature, 0, len(l.Features))
for _, f := range l.Features {
features = append(features, f.Feature)
}
return features
}
func (l *Layer) GetNamespaces() []Namespace {
namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features))
for _, ns := range l.Namespaces {
namespaces = append(namespaces, ns.Namespace)
}
for _, f := range l.Features {
if f.PotentialNamespace.Valid() {
namespaces = append(namespaces, f.PotentialNamespace)
}
}
return namespaces
}
// LayerNamespace is a namespace with detection information.
type LayerNamespace struct {
Namespace `json:"namespace"`
// By is the detector found the namespace.
By Detector `json:"by"`
}
// LayerFeature is a feature with detection information.
type LayerFeature struct {
Feature `json:"feature"`
// By is the detector found the feature.
By Detector `json:"by"`
PotentialNamespace Namespace `json:"potentialNamespace"`
}
// Namespace is the contextual information around features.
//
// e.g. Debian:7, NodeJS.
type Namespace struct {
Name string `json:"name"`
VersionFormat string `json:"versionFormat"`
}
func NewNamespace(name string, versionFormat string) *Namespace {
return &Namespace{name, versionFormat}
}
func (ns *Namespace) Valid() bool {
if ns.Name == "" || ns.VersionFormat == "" {
return false
}
return true
}
// Feature represents a package detected in a layer but the namespace is not
// determined.
//
// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary
// dpkg is the version format of the installer package manager, which in this
// case could be dpkg or apk.
type Feature struct {
Name string `json:"name"`
Version string `json:"version"`
VersionFormat string `json:"versionFormat"`
Type FeatureType `json:"type"`
}
func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature {
return &Feature{name, version, versionFormat, featureType}
}
func NewBinaryPackage(name string, version string, versionFormat string) *Feature {
return &Feature{name, version, versionFormat, BinaryPackage}
}
func NewSourcePackage(name string, version string, versionFormat string) *Feature {
return &Feature{name, version, versionFormat, SourcePackage}
}
// NamespacedFeature is a feature with determined namespace and can be affected
// by vulnerabilities.
//
// e.g. OpenSSL 1.0 dpkg Debian:7.
type NamespacedFeature struct {
Feature `json:"feature"`
Namespace Namespace `json:"namespace"`
}
func (nf *NamespacedFeature) Key() string {
return fmt.Sprintf("%s-%s-%s-%s-%s-%s", nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name, nf.Namespace.VersionFormat)
}
func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature {
// TODO: namespaced feature should use pointer values
return &NamespacedFeature{*feature, *namespace}
}
// AffectedNamespacedFeature is a namespaced feature affected by the
// vulnerabilities with fixed-in versions for this feature.
type AffectedNamespacedFeature struct {
NamespacedFeature
AffectedBy []VulnerabilityWithFixedIn
}
// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve
// the affecting vulnerabilities and the fixed-in versions for the feature.
type VulnerabilityWithFixedIn struct {
Vulnerability
FixedInVersion string
}
// AffectedFeature is used to determine whether a namespaced feature is affected
// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is
// bound to vulnerability.
type AffectedFeature struct {
// FeatureType determines which type of package it affects.
FeatureType FeatureType
Namespace Namespace
FeatureName string
// FixedInVersion is known next feature version that's not affected by the
// vulnerability. Empty FixedInVersion means the unaffected version is
// unknown.
FixedInVersion string
// AffectedVersion contains the version range to determine whether or not a
// feature is affected.
AffectedVersion string
}
// VulnerabilityID is an identifier for every vulnerability. Every vulnerability
// has unique namespace and name.
type VulnerabilityID struct {
Name string
Namespace string
}
// Vulnerability represents CVE or similar vulnerability reports.
type Vulnerability struct {
Name string
Namespace Namespace
Description string
Link string
Severity Severity
Metadata MetadataMap
}
// VulnerabilityWithAffected is a vulnerability with all known affected
// features.
type VulnerabilityWithAffected struct {
Vulnerability
Affected []AffectedFeature
}
// PagedVulnerableAncestries is a vulnerability with a page of affected
// ancestries each with a special index attached for streaming purpose. The
// current page number and next page number are for navigate.
type PagedVulnerableAncestries struct {
Vulnerability
// Affected is a map of special indexes to Ancestries, which the pair
// should be unique in a stream. Every indexes in the map should be larger
// than previous page.
Affected map[int]string
Limit int
Current pagination.Token
Next pagination.Token
// End signals the end of the pages.
End bool
}
// NotificationHook is a message sent to another service to inform of a change
// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains
// the name of a notification that should be read and marked as read via the
// API.
type NotificationHook struct {
Name string
Created time.Time
Notified time.Time
Deleted time.Time
}
// VulnerabilityNotification is a notification for vulnerability changes.
type VulnerabilityNotification struct {
NotificationHook
Old *Vulnerability
New *Vulnerability
}
// VulnerabilityNotificationWithVulnerable is a notification for vulnerability
// changes with vulnerable ancestries.
type VulnerabilityNotificationWithVulnerable struct {
NotificationHook
Old *PagedVulnerableAncestries
New *PagedVulnerableAncestries
}
// MetadataMap is for storing the metadata returned by vulnerability database.
type MetadataMap map[string]interface{}
// NullableAffectedNamespacedFeature is an affectednamespacedfeature with
// whether it's found in datastore.
type NullableAffectedNamespacedFeature struct {
AffectedNamespacedFeature
Valid bool
}
// NullableVulnerability is a vulnerability with whether the vulnerability is
// found in datastore.
type NullableVulnerability struct {
VulnerabilityWithAffected
Valid bool
}
func (mm *MetadataMap) Scan(value interface{}) error {
if value == nil {
return nil
}
// github.com/lib/pq decodes TEXT/VARCHAR fields into strings.
val, ok := value.(string)
if !ok {
panic("got type other than []byte from database")
}
return json.Unmarshal([]byte(val), mm)
}
func (mm *MetadataMap) Value() (driver.Value, error) {
json, err := json.Marshal(*mm)
return string(json), err
}

34
database/namespace.go Normal file
View File

@ -0,0 +1,34 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
// Namespace is the contextual information around features.
//
// e.g. Debian:7, NodeJS.
type Namespace struct {
Name string `json:"name"`
VersionFormat string `json:"versionFormat"`
}
func NewNamespace(name string, versionFormat string) *Namespace {
return &Namespace{name, versionFormat}
}
func (ns *Namespace) Valid() bool {
if ns.Name == "" || ns.VersionFormat == "" {
return false
}
return true
}

69
database/notification.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"time"
"github.com/coreos/clair/pkg/pagination"
)
// NotificationHook is a message sent to another service to inform of a change
// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains
// the name of a notification that should be read and marked as read via the
// API.
type NotificationHook struct {
Name string
Created time.Time
Notified time.Time
Deleted time.Time
}
// VulnerabilityNotification is a notification for vulnerability changes.
type VulnerabilityNotification struct {
NotificationHook
Old *Vulnerability
New *Vulnerability
}
// VulnerabilityNotificationWithVulnerable is a notification for vulnerability
// changes with vulnerable ancestries.
type VulnerabilityNotificationWithVulnerable struct {
NotificationHook
Old *PagedVulnerableAncestries
New *PagedVulnerableAncestries
}
// PagedVulnerableAncestries is a vulnerability with a page of affected
// ancestries each with a special index attached for streaming purpose. The
// current page number and next page number are for navigate.
type PagedVulnerableAncestries struct {
Vulnerability
// Affected is a map of special indexes to Ancestries, which the pair
// should be unique in a stream. Every indexes in the map should be larger
// than previous page.
Affected map[int]string
Limit int
Current pagination.Token
Next pagination.Token
// End signals the end of the pages.
End bool
}

View File

@ -1,375 +0,0 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"errors"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr"
)
const (
insertAncestry = `
INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
findAncestryLayerHashes = `
SELECT layer.hash, ancestry_layer.ancestry_index
FROM layer, ancestry_layer
WHERE ancestry_layer.ancestry_id = $1
AND ancestry_layer.layer_id = layer.id
ORDER BY ancestry_layer.ancestry_index ASC`
findAncestryFeatures = `
SELECT namespace.name, namespace.version_format, feature.name,
feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
WHERE ancestry_layer.ancestry_id = $1
AND feature_type.id = feature.type
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
AND namespaced_feature.feature_id = feature.id
AND namespaced_feature.namespace_id = namespace.id`
findAncestryID = `SELECT id FROM ancestry WHERE name = $1`
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
insertAncestryLayers = `
INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3)
RETURNING id`
insertAncestryFeatures = `
INSERT INTO ancestry_feature
(ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
($1, $2, $3, $4)`
)
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) {
var (
ancestry = database.Ancestry{Name: name}
err error
)
id, ok, err := tx.findAncestryID(name)
if !ok || err != nil {
return ancestry, ok, err
}
if ancestry.By, err = tx.findAncestryDetectors(id); err != nil {
return ancestry, false, err
}
if ancestry.Layers, err = tx.findAncestryLayers(id); err != nil {
return ancestry, false, err
}
return ancestry, true, nil
}
func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error {
if !ancestry.Valid() {
return database.ErrInvalidParameters
}
if err := tx.removeAncestry(ancestry.Name); err != nil {
return err
}
id, err := tx.insertAncestry(ancestry.Name)
if err != nil {
return err
}
detectorIDs, err := tx.findDetectorIDs(ancestry.By)
if err != nil {
return err
}
// insert ancestry metadata
if err := tx.insertAncestryDetectors(id, detectorIDs); err != nil {
return err
}
layers := make([]string, 0, len(ancestry.Layers))
for _, layer := range ancestry.Layers {
layers = append(layers, layer.Hash)
}
layerIDs, ok, err := tx.findLayerIDs(layers)
if err != nil {
return err
}
if !ok {
log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.")
return database.ErrMissingEntities
}
ancestryLayerIDs, err := tx.insertAncestryLayers(id, layerIDs)
if err != nil {
return err
}
for i, id := range ancestryLayerIDs {
if err := tx.insertAncestryFeatures(id, ancestry.Layers[i]); err != nil {
return err
}
}
return nil
}
func (tx *pgSession) insertAncestry(name string) (int64, error) {
var id int64
err := tx.QueryRow(insertAncestry, name).Scan(&id)
if err != nil {
if isErrUniqueViolation(err) {
return 0, handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
}
return 0, handleError("insertAncestry", err)
}
return id, nil
}
func (tx *pgSession) findAncestryID(name string) (int64, bool, error) {
var id sql.NullInt64
if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil {
if err == sql.ErrNoRows {
return 0, false, nil
}
return 0, false, handleError("findAncestryID", err)
}
return id.Int64, true, nil
}
func (tx *pgSession) removeAncestry(name string) error {
result, err := tx.Exec(removeAncestry, name)
if err != nil {
return handleError("removeAncestry", err)
}
affected, err := result.RowsAffected()
if err != nil {
return handleError("removeAncestry", err)
}
if affected != 0 {
log.WithField("ancestry", name).Debug("removed ancestry")
}
return nil
}
func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) {
detectors, err := tx.findAllDetectors()
if err != nil {
return nil, err
}
layerMap, err := tx.findAncestryLayerHashes(id)
if err != nil {
return nil, err
}
featureMap, err := tx.findAncestryFeatures(id, detectors)
if err != nil {
return nil, err
}
layers := make([]database.AncestryLayer, len(layerMap))
for index, layer := range layerMap {
// index MUST match the ancestry layer slice index.
if layers[index].Hash == "" && len(layers[index].Features) == 0 {
layers[index] = database.AncestryLayer{
Hash: layer,
Features: featureMap[index],
}
} else {
log.WithFields(log.Fields{
"ancestry ID": id,
"duplicated ancestry index": index,
}).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed")
return nil, database.ErrInconsistent
}
}
return layers, nil
}
func (tx *pgSession) findAncestryLayerHashes(ancestryID int64) (map[int64]string, error) {
// retrieve layer indexes and hashes
rows, err := tx.Query(findAncestryLayerHashes, ancestryID)
if err != nil {
return nil, handleError("findAncestryLayerHashes", err)
}
layerHashes := map[int64]string{}
for rows.Next() {
var (
hash string
index int64
)
if err = rows.Scan(&hash, &index); err != nil {
return nil, handleError("findAncestryLayerHashes", err)
}
if _, ok := layerHashes[index]; ok {
// one ancestry index should correspond to only one layer
return nil, database.ErrInconsistent
}
layerHashes[index] = hash
}
return layerHashes, nil
}
func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMap) (map[int64][]database.AncestryFeature, error) {
// ancestry_index -> ancestry features
featureMap := make(map[int64][]database.AncestryFeature)
// retrieve ancestry layer's namespaced features
rows, err := tx.Query(findAncestryFeatures, ancestryID)
if err != nil {
return nil, handleError("findAncestryFeatures", err)
}
defer rows.Close()
for rows.Next() {
var (
featureDetectorID int64
namespaceDetectorID int64
feature database.NamespacedFeature
// index is used to determine which layer the feature belongs to.
index sql.NullInt64
)
if err := rows.Scan(
&feature.Namespace.Name,
&feature.Namespace.VersionFormat,
&feature.Feature.Name,
&feature.Feature.Version,
&feature.Feature.VersionFormat,
&feature.Feature.Type,
&index,
&featureDetectorID,
&namespaceDetectorID,
); err != nil {
return nil, handleError("findAncestryFeatures", err)
}
if feature.Feature.VersionFormat != feature.Namespace.VersionFormat {
// Feature must have the same version format as the associated
// namespace version format.
return nil, database.ErrInconsistent
}
fDetector, ok := detectors.byID[featureDetectorID]
if !ok {
return nil, database.ErrInconsistent
}
nsDetector, ok := detectors.byID[namespaceDetectorID]
if !ok {
return nil, database.ErrInconsistent
}
featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{
NamespacedFeature: feature,
FeatureBy: fDetector,
NamespaceBy: nsDetector,
})
}
return featureMap, nil
}
// insertAncestryLayers inserts the ancestry layers along with its content into
// the database. The layers are 0 based indexed in the original order.
func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []int64) ([]int64, error) {
stmt, err := tx.Prepare(insertAncestryLayers)
if err != nil {
return nil, handleError("insertAncestryLayers", err)
}
ancestryLayerIDs := []int64{}
for index, layerID := range layers {
var ancestryLayerID sql.NullInt64
if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil {
return nil, handleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close()))
}
if !ancestryLayerID.Valid {
return nil, database.ErrInconsistent
}
ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64)
}
if err := stmt.Close(); err != nil {
return nil, handleError("insertAncestryLayers", err)
}
return ancestryLayerIDs, nil
}
func (tx *pgSession) insertAncestryFeatures(ancestryLayerID int64, layer database.AncestryLayer) error {
detectors, err := tx.findAllDetectors()
if err != nil {
return err
}
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(layer.GetFeatures())
if err != nil {
return err
}
// find the detectors for each feature
stmt, err := tx.Prepare(insertAncestryFeatures)
if err != nil {
return handleError("insertAncestryFeatures", err)
}
defer stmt.Close()
for index, id := range nsFeatureIDs {
if !id.Valid {
return database.ErrMissingEntities
}
namespaceDetectorID, ok := detectors.byValue[layer.Features[index].NamespaceBy]
if !ok {
return database.ErrMissingEntities
}
featureDetectorID, ok := detectors.byValue[layer.Features[index].FeatureBy]
if !ok {
return database.ErrMissingEntities
}
if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil {
return handleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close()))
}
}
return nil
}

View File

@ -0,0 +1,160 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ancestry
import (
"database/sql"
"errors"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/layer"
"github.com/coreos/clair/database/pgsql/util"
)
const (
insertAncestry = `
INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
findAncestryID = `SELECT id FROM ancestry WHERE name = $1`
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
insertAncestryFeatures = `
INSERT INTO ancestry_feature
(ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
($1, $2, $3, $4)`
)
func FindAncestry(tx *sql.Tx, name string) (database.Ancestry, bool, error) {
var (
ancestry = database.Ancestry{Name: name}
err error
)
id, ok, err := FindAncestryID(tx, name)
if !ok || err != nil {
return ancestry, ok, err
}
if ancestry.By, err = FindAncestryDetectors(tx, id); err != nil {
return ancestry, false, err
}
if ancestry.Layers, err = FindAncestryLayers(tx, id); err != nil {
return ancestry, false, err
}
return ancestry, true, nil
}
func UpsertAncestry(tx *sql.Tx, ancestry database.Ancestry) error {
if !ancestry.Valid() {
return database.ErrInvalidParameters
}
if err := RemoveAncestry(tx, ancestry.Name); err != nil {
return err
}
id, err := InsertAncestry(tx, ancestry.Name)
if err != nil {
return err
}
detectorIDs, err := detector.FindDetectorIDs(tx, ancestry.By)
if err != nil {
return err
}
// insert ancestry metadata
if err := InsertAncestryDetectors(tx, id, detectorIDs); err != nil {
return err
}
layers := make([]string, 0, len(ancestry.Layers))
for _, l := range ancestry.Layers {
layers = append(layers, l.Hash)
}
layerIDs, ok, err := layer.FindLayerIDs(tx, layers)
if err != nil {
return err
}
if !ok {
log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.")
return database.ErrMissingEntities
}
ancestryLayerIDs, err := InsertAncestryLayers(tx, id, layerIDs)
if err != nil {
return err
}
for i, id := range ancestryLayerIDs {
if err := InsertAncestryFeatures(tx, id, ancestry.Layers[i]); err != nil {
return err
}
}
return nil
}
func InsertAncestry(tx *sql.Tx, name string) (int64, error) {
var id int64
err := tx.QueryRow(insertAncestry, name).Scan(&id)
if err != nil {
if util.IsErrUniqueViolation(err) {
return 0, util.HandleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
}
return 0, util.HandleError("insertAncestry", err)
}
return id, nil
}
func FindAncestryID(tx *sql.Tx, name string) (int64, bool, error) {
var id sql.NullInt64
if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil {
if err == sql.ErrNoRows {
return 0, false, nil
}
return 0, false, util.HandleError("findAncestryID", err)
}
return id.Int64, true, nil
}
func RemoveAncestry(tx *sql.Tx, name string) error {
result, err := tx.Exec(removeAncestry, name)
if err != nil {
return util.HandleError("removeAncestry", err)
}
affected, err := result.RowsAffected()
if err != nil {
return util.HandleError("removeAncestry", err)
}
if affected != 0 {
log.WithField("ancestry", name).Debug("removed ancestry")
}
return nil
}

View File

@ -0,0 +1,48 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ancestry
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/util"
)
var selectAncestryDetectors = `
SELECT d.name, d.version, d.dtype
FROM ancestry_detector, detector AS d
WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`
var insertAncestryDetectors = `
INSERT INTO ancestry_detector (ancestry_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`
func FindAncestryDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) {
detectors, err := detector.GetDetectors(tx, selectAncestryDetectors, id)
return detectors, err
}
func InsertAncestryDetectors(tx *sql.Tx, ancestryID int64, detectorIDs []int64) error {
for _, detectorID := range detectorIDs {
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
return util.HandleError("insertAncestryDetectors", err)
}
}
return nil
}

View File

@ -0,0 +1,139 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ancestry
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr"
)
const findAncestryFeatures = `
SELECT namespace.name, namespace.version_format, feature.name,
feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
WHERE ancestry_layer.ancestry_id = $1
AND feature_type.id = feature.type
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
AND namespaced_feature.feature_id = feature.id
AND namespaced_feature.namespace_id = namespace.id`
func FindAncestryFeatures(tx *sql.Tx, ancestryID int64, detectors detector.DetectorMap) (map[int64][]database.AncestryFeature, error) {
// ancestry_index -> ancestry features
featureMap := make(map[int64][]database.AncestryFeature)
// retrieve ancestry layer's namespaced features
rows, err := tx.Query(findAncestryFeatures, ancestryID)
if err != nil {
return nil, util.HandleError("findAncestryFeatures", err)
}
defer rows.Close()
for rows.Next() {
var (
featureDetectorID int64
namespaceDetectorID int64
feature database.NamespacedFeature
// index is used to determine which layer the feature belongs to.
index sql.NullInt64
)
if err := rows.Scan(
&feature.Namespace.Name,
&feature.Namespace.VersionFormat,
&feature.Feature.Name,
&feature.Feature.Version,
&feature.Feature.VersionFormat,
&feature.Feature.Type,
&index,
&featureDetectorID,
&namespaceDetectorID,
); err != nil {
return nil, util.HandleError("findAncestryFeatures", err)
}
if feature.Feature.VersionFormat != feature.Namespace.VersionFormat {
// Feature must have the same version format as the associated
// namespace version format.
return nil, database.ErrInconsistent
}
fDetector, ok := detectors.ByID[featureDetectorID]
if !ok {
return nil, database.ErrInconsistent
}
nsDetector, ok := detectors.ByID[namespaceDetectorID]
if !ok {
return nil, database.ErrInconsistent
}
featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{
NamespacedFeature: feature,
FeatureBy: fDetector,
NamespaceBy: nsDetector,
})
}
return featureMap, nil
}
func InsertAncestryFeatures(tx *sql.Tx, ancestryLayerID int64, layer database.AncestryLayer) error {
detectors, err := detector.FindAllDetectors(tx)
if err != nil {
return err
}
nsFeatureIDs, err := feature.FindNamespacedFeatureIDs(tx, layer.GetFeatures())
if err != nil {
return err
}
// find the detectors for each feature
stmt, err := tx.Prepare(insertAncestryFeatures)
if err != nil {
return util.HandleError("insertAncestryFeatures", err)
}
defer stmt.Close()
for index, id := range nsFeatureIDs {
if !id.Valid {
return database.ErrMissingEntities
}
namespaceDetectorID, ok := detectors.ByValue[layer.Features[index].NamespaceBy]
if !ok {
return database.ErrMissingEntities
}
featureDetectorID, ok := detectors.ByValue[layer.Features[index].FeatureBy]
if !ok {
return database.ErrMissingEntities
}
if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil {
return util.HandleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close()))
}
}
return nil
}

View File

@ -0,0 +1,131 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ancestry
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr"
log "github.com/sirupsen/logrus"
)
const (
findAncestryLayerHashes = `
SELECT layer.hash, ancestry_layer.ancestry_index
FROM layer, ancestry_layer
WHERE ancestry_layer.ancestry_id = $1
AND ancestry_layer.layer_id = layer.id
ORDER BY ancestry_layer.ancestry_index ASC`
insertAncestryLayers = `
INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3)
RETURNING id`
)
func FindAncestryLayerHashes(tx *sql.Tx, ancestryID int64) (map[int64]string, error) {
// retrieve layer indexes and hashes
rows, err := tx.Query(findAncestryLayerHashes, ancestryID)
if err != nil {
return nil, util.HandleError("findAncestryLayerHashes", err)
}
layerHashes := map[int64]string{}
for rows.Next() {
var (
hash string
index int64
)
if err = rows.Scan(&hash, &index); err != nil {
return nil, util.HandleError("findAncestryLayerHashes", err)
}
if _, ok := layerHashes[index]; ok {
// one ancestry index should correspond to only one layer
return nil, database.ErrInconsistent
}
layerHashes[index] = hash
}
return layerHashes, nil
}
// insertAncestryLayers inserts the ancestry layers along with its content into
// the database. The layers are 0 based indexed in the original order.
func InsertAncestryLayers(tx *sql.Tx, ancestryID int64, layers []int64) ([]int64, error) {
stmt, err := tx.Prepare(insertAncestryLayers)
if err != nil {
return nil, util.HandleError("insertAncestryLayers", err)
}
ancestryLayerIDs := []int64{}
for index, layerID := range layers {
var ancestryLayerID sql.NullInt64
if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil {
return nil, util.HandleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close()))
}
if !ancestryLayerID.Valid {
return nil, database.ErrInconsistent
}
ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64)
}
if err := stmt.Close(); err != nil {
return nil, util.HandleError("insertAncestryLayers", err)
}
return ancestryLayerIDs, nil
}
func FindAncestryLayers(tx *sql.Tx, id int64) ([]database.AncestryLayer, error) {
detectors, err := detector.FindAllDetectors(tx)
if err != nil {
return nil, err
}
layerMap, err := FindAncestryLayerHashes(tx, id)
if err != nil {
return nil, err
}
featureMap, err := FindAncestryFeatures(tx, id, detectors)
if err != nil {
return nil, err
}
layers := make([]database.AncestryLayer, len(layerMap))
for index, layer := range layerMap {
// index MUST match the ancestry layer slice index.
if layers[index].Hash == "" && len(layers[index].Features) == 0 {
layers[index] = database.AncestryLayer{
Hash: layer,
Features: featureMap[index],
}
} else {
log.WithFields(log.Fields{
"ancestry ID": id,
"duplicated ancestry index": index,
}).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed")
return nil, database.ErrInconsistent
}
}
return layers, nil
}

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2019 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package ancestry
import ( import (
"testing" "testing"
@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
var upsertAncestryTests = []struct { var upsertAncestryTests = []struct {
@ -55,9 +56,9 @@ var upsertAncestryTests = []struct {
title: "ancestry with invalid feature", title: "ancestry with invalid feature",
in: &database.Ancestry{ in: &database.Ancestry{
Name: "a", Name: "a",
By: []database.Detector{realDetectors[1], realDetectors[2]}, By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{ Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{
{fakeNamespacedFeatures[1], fakeDetector[1], fakeDetector[2]}, {testutil.FakeNamespacedFeatures[1], testutil.FakeDetector[1], testutil.FakeDetector[2]},
}}}, }}},
}, },
err: database.ErrMissingEntities.Error(), err: database.ErrMissingEntities.Error(),
@ -66,26 +67,27 @@ var upsertAncestryTests = []struct {
title: "replace old ancestry", title: "replace old ancestry",
in: &database.Ancestry{ in: &database.Ancestry{
Name: "a", Name: "a",
By: []database.Detector{realDetectors[1], realDetectors[2]}, By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
Layers: []database.AncestryLayer{ Layers: []database.AncestryLayer{
{"layer-1", []database.AncestryFeature{{realNamespacedFeatures[1], realDetectors[2], realDetectors[1]}}}, {"layer-1", []database.AncestryFeature{{testutil.RealNamespacedFeatures[1], testutil.RealDetectors[2], testutil.RealDetectors[1]}}},
}, },
}, },
}, },
} }
func TestUpsertAncestry(t *testing.T) { func TestUpsertAncestry(t *testing.T) {
store, tx := openSessionForTest(t, "UpsertAncestry", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestUpsertAncestry")
defer closeTest(t, store, tx) defer cleanup()
for _, test := range upsertAncestryTests { for _, test := range upsertAncestryTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
err := tx.UpsertAncestry(*test.in) err := UpsertAncestry(tx, *test.in)
if test.err != "" { if test.err != "" {
assert.EqualError(t, err, test.err, "unexpected error") assert.EqualError(t, err, test.err, "unexpected error")
return return
} }
assert.Nil(t, err) assert.Nil(t, err)
actual, ok, err := tx.FindAncestry(test.in.Name) actual, ok, err := FindAncestry(tx, test.in.Name)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
database.AssertAncestryEqual(t, test.in, &actual) database.AssertAncestryEqual(t, test.in, &actual)
@ -113,16 +115,17 @@ var findAncestryTests = []struct {
in: "ancestry-2", in: "ancestry-2",
err: "", err: "",
ok: true, ok: true,
ancestry: takeAncestryPointerFromMap(realAncestries, 2), ancestry: testutil.TakeAncestryPointerFromMap(testutil.RealAncestries, 2),
}, },
} }
func TestFindAncestry(t *testing.T) { func TestFindAncestry(t *testing.T) {
store, tx := openSessionForTest(t, "FindAncestry", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindAncestry")
defer closeTest(t, store, tx) defer cleanup()
for _, test := range findAncestryTests { for _, test := range findAncestryTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
ancestry, ok, err := tx.FindAncestry(test.in) ancestry, ok, err := FindAncestry(tx, test.in)
if test.err != "" { if test.err != "" {
assert.EqualError(t, err, test.err, "unexpected error") assert.EqualError(t, err, test.err, "unexpected error")
return return

View File

@ -1,172 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"math/rand"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/strutil"
)
const (
numVulnerabilities = 100
numFeatures = 100
)
func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
tx, err := store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
featureName := "TestFeature"
featureVersionFormat := dpkg.ParserName
// Insert the namespace on which we'll work.
namespace := database.Namespace{
Name: "TestRaceAffectsFeatureNamespace1",
VersionFormat: dpkg.ParserName,
}
if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) {
t.FailNow()
}
// Initialize random generator and enforce max procs.
rand.Seed(time.Now().UnixNano())
runtime.GOMAXPROCS(runtime.NumCPU())
// Generate Distinct random features
features := make([]database.Feature, numFeatures)
nsFeatures := make([]database.NamespacedFeature, numFeatures)
for i := 0; i < numFeatures; i++ {
version := rand.Intn(numFeatures)
features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
nsFeatures[i] = database.NamespacedFeature{
Namespace: namespace,
Feature: features[i],
}
}
if !assert.Nil(t, tx.PersistFeatures(features)) {
t.FailNow()
}
// Generate vulnerabilities.
vulnerabilities := []database.VulnerabilityWithAffected{}
for i := 0; i < numVulnerabilities; i++ {
// any version less than this is vulnerable
version := rand.Intn(numFeatures) + 1
vulnerability := database.VulnerabilityWithAffected{
Vulnerability: database.Vulnerability{
Name: uuid.New(),
Namespace: namespace,
Severity: database.UnknownSeverity,
},
Affected: []database.AffectedFeature{
{
Namespace: namespace,
FeatureName: featureName,
FeatureType: database.SourcePackage,
AffectedVersion: strconv.Itoa(version),
FixedInVersion: strconv.Itoa(version),
},
},
}
vulnerabilities = append(vulnerabilities, vulnerability)
}
tx.Commit()
return nsFeatures, vulnerabilities
}
func TestConcurrency(t *testing.T) {
store, cleanup := createTestPgSQL(t, "concurrency")
defer cleanup()
var wg sync.WaitGroup
// there's a limit on the number of concurrent connections in the pool
wg.Add(30)
for i := 0; i < 30; i++ {
go func() {
defer wg.Done()
nsNamespaces := genRandomNamespaces(t, 100)
tx, err := store.Begin()
require.Nil(t, err)
require.Nil(t, tx.PersistNamespaces(nsNamespaces))
require.Nil(t, tx.Commit())
}()
}
wg.Wait()
}
func TestCaching(t *testing.T) {
store, cleanup := createTestPgSQL(t, "caching")
defer cleanup()
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
tx, err := store.Begin()
require.Nil(t, err)
require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
require.Nil(t, tx.Commit())
tx, err = store.Begin()
require.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
require.Nil(t, tx.Commit())
tx, err = store.Begin()
require.Nil(t, err)
defer tx.Rollback()
affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
require.Nil(t, err)
for _, ansf := range affected {
require.True(t, ansf.Valid)
expectedAffectedNames := []string{}
for _, vuln := range vulnerabilities {
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
if ok {
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
}
}
}
actualAffectedNames := []string{}
for _, s := range ansf.AffectedBy {
actualAffectedNames = append(actualAffectedNames, s.Name)
}
require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
}
}

View File

@ -1,196 +0,0 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"github.com/deckarep/golang-set"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
)
const (
soiDetector = `
INSERT INTO detector (name, version, dtype)
SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`
selectAncestryDetectors = `
SELECT d.name, d.version, d.dtype
FROM ancestry_detector, detector AS d
WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`
selectLayerDetectors = `
SELECT d.name, d.version, d.dtype
FROM layer_detector, detector AS d
WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`
insertAncestryDetectors = `
INSERT INTO ancestry_detector (ancestry_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`
persistLayerDetector = `
INSERT INTO layer_detector (layer_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`
findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
findAllDetectors = `SELECT id, name, version, dtype FROM detector`
)
type detectorMap struct {
byID map[int64]database.Detector
byValue map[database.Detector]int64
}
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
for _, d := range detectors {
if !d.Valid() {
log.WithField("detector", d).Debug("Invalid Detector")
return database.ErrInvalidParameters
}
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
if err != nil {
return handleError("soiDetector", err)
}
count, err := r.RowsAffected()
if err != nil {
return handleError("soiDetector", err)
}
if count == 0 {
log.Debug("detector already exists: ", d)
}
}
return nil
}
func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error {
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
return handleError("persistLayerDetector", err)
}
return nil
}
func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error {
alreadySaved := mapset.NewSet()
for _, id := range detectorIDs {
if alreadySaved.Contains(id) {
continue
}
alreadySaved.Add(id)
if err := tx.persistLayerDetector(layerID, id); err != nil {
return err
}
}
return nil
}
func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error {
for _, detectorID := range detectorIDs {
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
return handleError("insertAncestryDetectors", err)
}
}
return nil
}
func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) {
detectors, err := tx.getDetectors(selectAncestryDetectors, id)
return detectors, err
}
func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) {
detectors, err := tx.getDetectors(selectLayerDetectors, id)
return detectors, err
}
// findDetectorIDs retrieve ids of the detectors from the database, if any is not
// found, return the error.
func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) {
ids := []int64{}
for _, d := range detectors {
id := sql.NullInt64{}
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
if err != nil {
return nil, handleError("findDetectorID", err)
}
if !id.Valid {
return nil, database.ErrInconsistent
}
ids = append(ids, id.Int64)
}
return ids, nil
}
func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) {
rows, err := tx.Query(query, id)
if err != nil {
return nil, handleError("getDetectors", err)
}
detectors := []database.Detector{}
for rows.Next() {
d := database.Detector{}
err := rows.Scan(&d.Name, &d.Version, &d.DType)
if err != nil {
return nil, handleError("getDetectors", err)
}
if !d.Valid() {
return nil, database.ErrInvalidDetector
}
detectors = append(detectors, d)
}
return detectors, nil
}
func (tx *pgSession) findAllDetectors() (detectorMap, error) {
rows, err := tx.Query(findAllDetectors)
if err != nil {
return detectorMap{}, handleError("searchAllDetectors", err)
}
detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)}
for rows.Next() {
var (
id int64
d database.Detector
)
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
return detectorMap{}, handleError("searchAllDetectors", err)
}
detectors.byID[id] = d
detectors.byValue[d] = id
}
return detectors, nil
}

View File

@ -0,0 +1,132 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package detector
import (
"database/sql"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/util"
)
const (
soiDetector = `
INSERT INTO detector (name, version, dtype)
SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`
findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
findAllDetectors = `SELECT id, name, version, dtype FROM detector`
)
type DetectorMap struct {
ByID map[int64]database.Detector
ByValue map[database.Detector]int64
}
func PersistDetectors(tx *sql.Tx, detectors []database.Detector) error {
for _, d := range detectors {
if !d.Valid() {
log.WithField("detector", d).Debug("Invalid Detector")
return database.ErrInvalidParameters
}
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
if err != nil {
return util.HandleError("soiDetector", err)
}
count, err := r.RowsAffected()
if err != nil {
return util.HandleError("soiDetector", err)
}
if count == 0 {
log.Debug("detector already exists: ", d)
}
}
return nil
}
// findDetectorIDs retrieve ids of the detectors from the database, if any is not
// found, return the error.
func FindDetectorIDs(tx *sql.Tx, detectors []database.Detector) ([]int64, error) {
ids := []int64{}
for _, d := range detectors {
id := sql.NullInt64{}
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
if err != nil {
return nil, util.HandleError("findDetectorID", err)
}
if !id.Valid {
return nil, database.ErrInconsistent
}
ids = append(ids, id.Int64)
}
return ids, nil
}
func GetDetectors(tx *sql.Tx, query string, id int64) ([]database.Detector, error) {
rows, err := tx.Query(query, id)
if err != nil {
return nil, util.HandleError("getDetectors", err)
}
detectors := []database.Detector{}
for rows.Next() {
d := database.Detector{}
err := rows.Scan(&d.Name, &d.Version, &d.DType)
if err != nil {
return nil, util.HandleError("getDetectors", err)
}
if !d.Valid() {
return nil, database.ErrInvalidDetector
}
detectors = append(detectors, d)
}
return detectors, nil
}
func FindAllDetectors(tx *sql.Tx) (DetectorMap, error) {
rows, err := tx.Query(findAllDetectors)
if err != nil {
return DetectorMap{}, util.HandleError("searchAllDetectors", err)
}
detectors := DetectorMap{ByID: make(map[int64]database.Detector), ByValue: make(map[database.Detector]int64)}
for rows.Next() {
var (
id int64
d database.Detector
)
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
return DetectorMap{}, util.HandleError("searchAllDetectors", err)
}
detectors.ByID[id] = d
detectors.ByValue[d] = id
}
return detectors, nil
}

View File

@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package detector
import ( import (
"database/sql"
"testing" "testing"
"github.com/deckarep/golang-set" "github.com/deckarep/golang-set"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
func testGetAllDetectors(tx *pgSession) []database.Detector { func testGetAllDetectors(tx *sql.Tx) []database.Detector {
query := `SELECT name, version, dtype FROM detector` query := `SELECT name, version, dtype FROM detector`
rows, err := tx.Query(query) rows, err := tx.Query(query)
if err != nil { if err != nil {
@ -90,12 +92,12 @@ var persistDetectorTests = []struct {
} }
func TestPersistDetector(t *testing.T) { func TestPersistDetector(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistDetector", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistDetector")
defer closeTest(t, datastore, tx) defer cleanup()
for _, test := range persistDetectorTests { for _, test := range persistDetectorTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
err := tx.PersistDetectors(test.in) err := PersistDetectors(tx, test.in)
if test.err != "" { if test.err != "" {
require.EqualError(t, err, test.err) require.EqualError(t, err, test.err)
return return

View File

@ -1,398 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"sort"
"github.com/lib/pq"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/pkg/commonerr"
)
const (
soiNamespacedFeature = `
WITH new_feature_ns AS (
INSERT INTO namespaced_feature(feature_id, namespace_id)
SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
RETURNING id
)
SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
UNION
SELECT id FROM new_feature_ns`
searchPotentialAffectingVulneraibilities = `
SELECT nf.id, v.id, vaf.affected_version, vaf.id
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
namespaced_feature AS nf, feature AS f
WHERE nf.id = ANY($1)
AND nf.feature_id = f.id
AND nf.namespace_id = v.namespace_id
AND vaf.feature_name = f.name
AND vaf.feature_type = f.type
AND vaf.vulnerability_id = v.id
AND v.deleted_at IS NULL`
searchNamespacedFeaturesVulnerabilities = `
SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
FROM vulnerability_affected_namespaced_feature AS vanf,
Vulnerability AS v,
vulnerability_affected_feature AS vaf,
namespace AS n
WHERE vanf.namespaced_feature_id = ANY($1)
AND vaf.id = vanf.added_by
AND v.id = vanf.vulnerability_id
AND n.id = v.namespace_id
AND v.deleted_at IS NULL`
)
func (tx *pgSession) PersistFeatures(features []database.Feature) error {
if len(features) == 0 {
return nil
}
types, err := tx.getFeatureTypeMap()
if err != nil {
return err
}
// Sorting is needed before inserting into database to prevent deadlock.
sort.Slice(features, func(i, j int) bool {
return features[i].Name < features[j].Name ||
features[i].Version < features[j].Version ||
features[i].VersionFormat < features[j].VersionFormat
})
// TODO(Sida): A better interface for bulk insertion is needed.
keys := make([]interface{}, 0, len(features)*3)
for _, f := range features {
keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type])
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
}
}
_, err = tx.Exec(queryPersistFeature(len(features)), keys...)
return handleError("queryPersistFeature", err)
}
type namespacedFeatureWithID struct {
database.NamespacedFeature
ID int64
}
type vulnerabilityCache struct {
nsFeatureID int64
vulnID int64
vulnAffectingID int64
}
func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
if len(features) == 0 {
return nil, nil
}
ids, err := tx.findNamespacedFeatureIDs(features)
if err != nil {
return nil, err
}
fMap := map[int64]database.NamespacedFeature{}
for i, f := range features {
if !ids[i].Valid {
return nil, database.ErrMissingEntities
}
fMap[ids[i].Int64] = f
}
cacheTable := []vulnerabilityCache{}
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
if err != nil {
return nil, handleError("searchPotentialAffectingVulneraibilities", err)
}
defer rows.Close()
for rows.Next() {
var (
cache vulnerabilityCache
affected string
)
err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
if err != nil {
return nil, err
}
if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
return nil, err
} else if ok {
cacheTable = append(cacheTable, cache)
}
}
return cacheTable, nil
}
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
_, err := tx.Exec(lockVulnerabilityAffects)
if err != nil {
return handleError("lockVulnerabilityAffects", err)
}
cache, err := tx.searchAffectingVulnerabilities(features)
if err != nil {
return err
}
keys := make([]interface{}, 0, len(cache)*3)
for _, c := range cache {
keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID)
}
if len(cache) == 0 {
return nil
}
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
if err != nil {
return handleError("persistVulnerabilityAffectedNamespacedFeature", err)
}
if count, err := affected.RowsAffected(); err != nil {
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
}
return nil
}
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
nsIDs := map[database.Namespace]sql.NullInt64{}
fIDs := map[database.Feature]sql.NullInt64{}
for _, f := range features {
nsIDs[f.Namespace] = sql.NullInt64{}
fIDs[f.Feature] = sql.NullInt64{}
}
fToFind := []database.Feature{}
for f := range fIDs {
fToFind = append(fToFind, f)
}
sort.Slice(fToFind, func(i, j int) bool {
return fToFind[i].Name < fToFind[j].Name ||
fToFind[i].Version < fToFind[j].Version ||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
})
if ids, err := tx.findFeatureIDs(fToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return database.ErrMissingEntities
}
fIDs[fToFind[i]] = id
}
} else {
return err
}
nsToFind := []database.Namespace{}
for ns := range nsIDs {
nsToFind = append(nsToFind, ns)
}
if ids, err := tx.findNamespaceIDs(nsToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return database.ErrMissingEntities
}
nsIDs[nsToFind[i]] = id
}
} else {
return err
}
keys := make([]interface{}, 0, len(features)*2)
for _, f := range features {
keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace])
}
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
if err != nil {
return err
}
return nil
}
// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
// feature.
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
if len(features) == 0 {
return nil, nil
}
vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
featureIDs, err := tx.findNamespacedFeatureIDs(features)
if err != nil {
return nil, err
}
for i, id := range featureIDs {
if id.Valid {
vulnerableFeatures[i].Valid = true
vulnerableFeatures[i].NamespacedFeature = features[i]
}
}
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
defer rows.Close()
for rows.Next() {
var (
featureID int64
vuln database.VulnerabilityWithFixedIn
)
err := rows.Scan(&featureID,
&vuln.Name,
&vuln.Description,
&vuln.Link,
&vuln.Severity,
&vuln.Metadata,
&vuln.FixedInVersion,
&vuln.Namespace.Name,
&vuln.Namespace.VersionFormat,
)
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
for i, id := range featureIDs {
if id.Valid && id.Int64 == featureID {
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
}
}
}
return vulnerableFeatures, nil
}
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
if len(nfs) == 0 {
return nil, nil
}
nfsMap := map[database.NamespacedFeature]int64{}
keys := make([]interface{}, 0, len(nfs)*5)
for _, nf := range nfs {
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
}
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
if err != nil {
return nil, handleError("searchNamespacedFeature", err)
}
defer rows.Close()
var (
id int64
nf database.NamespacedFeature
)
for rows.Next() {
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
nf.Namespace.VersionFormat = nf.VersionFormat
if err != nil {
return nil, handleError("searchNamespacedFeature", err)
}
nfsMap[nf] = id
}
ids := make([]sql.NullInt64, len(nfs))
for i, nf := range nfs {
if id, ok := nfsMap[nf]; ok {
ids[i] = sql.NullInt64{id, true}
} else {
ids[i] = sql.NullInt64{}
}
}
return ids, nil
}
func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) {
if len(fs) == 0 {
return nil, nil
}
types, err := tx.getFeatureTypeMap()
if err != nil {
return nil, err
}
fMap := map[database.Feature]sql.NullInt64{}
keys := make([]interface{}, 0, len(fs)*4)
for _, f := range fs {
typeID := types.byName[f.Type]
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
fMap[f] = sql.NullInt64{}
}
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
if err != nil {
return nil, handleError("querySearchFeatureID", err)
}
defer rows.Close()
var (
id sql.NullInt64
f database.Feature
)
for rows.Next() {
var typeID int
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
if err != nil {
return nil, handleError("querySearchFeatureID", err)
}
f.Type = types.byID[typeID]
fMap[f] = id
}
ids := make([]sql.NullInt64, len(fs))
for i, f := range fs {
ids[i] = fMap[f]
}
return ids, nil
}

View File

@ -0,0 +1,121 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package feature
import (
"database/sql"
"fmt"
"sort"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr"
)
func queryPersistFeature(count int) string {
return util.QueryPersist(count,
"feature",
"feature_name_version_version_format_type_key",
"name",
"version",
"version_format",
"type")
}
func querySearchFeatureID(featureCount int) string {
return fmt.Sprintf(`
SELECT id, name, version, version_format, type
FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
util.QueryString(4, featureCount),
)
}
func PersistFeatures(tx *sql.Tx, features []database.Feature) error {
if len(features) == 0 {
return nil
}
types, err := GetFeatureTypeMap(tx)
if err != nil {
return err
}
// Sorting is needed before inserting into database to prevent deadlock.
sort.Slice(features, func(i, j int) bool {
return features[i].Name < features[j].Name ||
features[i].Version < features[j].Version ||
features[i].VersionFormat < features[j].VersionFormat
})
// TODO(Sida): A better interface for bulk insertion is needed.
keys := make([]interface{}, 0, len(features)*3)
for _, f := range features {
keys = append(keys, f.Name, f.Version, f.VersionFormat, types.ByName[f.Type])
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
}
}
_, err = tx.Exec(queryPersistFeature(len(features)), keys...)
return util.HandleError("queryPersistFeature", err)
}
func FindFeatureIDs(tx *sql.Tx, fs []database.Feature) ([]sql.NullInt64, error) {
if len(fs) == 0 {
return nil, nil
}
types, err := GetFeatureTypeMap(tx)
if err != nil {
return nil, err
}
fMap := map[database.Feature]sql.NullInt64{}
keys := make([]interface{}, 0, len(fs)*4)
for _, f := range fs {
typeID := types.ByName[f.Type]
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
fMap[f] = sql.NullInt64{}
}
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
if err != nil {
return nil, util.HandleError("querySearchFeatureID", err)
}
defer rows.Close()
var (
id sql.NullInt64
f database.Feature
)
for rows.Next() {
var typeID int
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
if err != nil {
return nil, util.HandleError("querySearchFeatureID", err)
}
f.Type = types.ByID[typeID]
fMap[f] = id
}
ids := make([]sql.NullInt64, len(fs))
for i, f := range fs {
ids[i] = fMap[f]
}
return ids, nil
}

View File

@ -12,36 +12,38 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package feature
import ( import (
"database/sql"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
func TestPersistFeatures(t *testing.T) { func TestPersistFeatures(t *testing.T) {
tx, cleanup := createTestPgSession(t, "TestPersistFeatures") tx, cleanup := testutil.CreateTestTx(t, "TestPersistFeatures")
defer cleanup() defer cleanup()
invalid := database.Feature{} invalid := database.Feature{}
valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg") valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg")
// invalid // invalid
require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid})) require.NotNil(t, PersistFeatures(tx, []database.Feature{invalid}))
// existing // existing
require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) require.Nil(t, PersistFeatures(tx, []database.Feature{valid}))
require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) require.Nil(t, PersistFeatures(tx, []database.Feature{valid}))
features := selectAllFeatures(t, tx) features := selectAllFeatures(t, tx)
assert.Equal(t, []database.Feature{valid}, features) assert.Equal(t, []database.Feature{valid}, features)
} }
func TestPersistNamespacedFeatures(t *testing.T) { func TestPersistNamespacedFeatures(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures") tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestPersistNamespacedFeatures")
defer cleanup() defer cleanup()
// existing features // existing features
@ -58,42 +60,17 @@ func TestPersistNamespacedFeatures(t *testing.T) {
nf2 := database.NewNamespacedFeature(n2, f2) nf2 := database.NewNamespacedFeature(n2, f2)
// namespaced features with namespaces or features not in the database will // namespaced features with namespaces or features not in the database will
// generate error. // generate error.
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2})) assert.NotNil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1, *nf2}))
// valid case: insert nf3 // valid case: insert nf3
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1})) assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1}))
all := listNamespacedFeatures(t, tx) all := listNamespacedFeatures(t, tx)
assert.Contains(t, all, *nf1) assert.Contains(t, all, *nf1)
} }
func TestFindAffectedNamespacedFeatures(t *testing.T) { func listNamespacedFeatures(t *testing.T, tx *sql.Tx) []database.NamespacedFeature {
datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true) types, err := GetFeatureTypeMap(tx)
defer closeTest(t, datastore, tx)
ns := database.NamespacedFeature{
Feature: database.Feature{
Name: "openssl",
Version: "1.0",
VersionFormat: "dpkg",
Type: database.SourcePackage,
},
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
},
}
ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns})
if assert.Nil(t, err) &&
assert.Len(t, ans, 1) &&
assert.True(t, ans[0].Valid) &&
assert.Len(t, ans[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
}
}
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
types, err := tx.getFeatureTypeMap()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -114,15 +91,15 @@ func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFe
panic(err) panic(err)
} }
f.Type = types.byID[typeID] f.Type = types.ByID[typeID]
nf = append(nf, f) nf = append(nf, f)
} }
return nf return nf
} }
func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature { func selectAllFeatures(t *testing.T, tx *sql.Tx) []database.Feature {
types, err := tx.getFeatureTypeMap() types, err := GetFeatureTypeMap(tx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -137,7 +114,7 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
f := database.Feature{} f := database.Feature{}
var typeID int var typeID int
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID) err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID)
f.Type = types.byID[typeID] f.Type = types.ByID[typeID]
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
@ -146,45 +123,24 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
return fs return fs
} }
func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.NamespacedFeature]bool{}
for _, nf := range expected {
has[nf] = false
}
for _, nf := range actual {
has[nf] = true
}
for nf, visited := range has {
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
return false
}
}
return true
}
return false
}
func TestFindNamespacedFeatureIDs(t *testing.T) { func TestFindNamespacedFeatureIDs(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs") tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNamespacedFeatureIDs")
defer cleanup() defer cleanup()
features := []database.NamespacedFeature{} features := []database.NamespacedFeature{}
expectedIDs := []int{} expectedIDs := []int{}
for id, feature := range realNamespacedFeatures { for id, feature := range testutil.RealNamespacedFeatures {
features = append(features, feature) features = append(features, feature)
expectedIDs = append(expectedIDs, id) expectedIDs = append(expectedIDs, id)
} }
features = append(features, realNamespacedFeatures[1]) // test duplicated features = append(features, testutil.RealNamespacedFeatures[1]) // test duplicated
expectedIDs = append(expectedIDs, 1) expectedIDs = append(expectedIDs, 1)
namespace := realNamespaces[1] namespace := testutil.RealNamespaces[1]
features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature
ids, err := tx.findNamespacedFeatureIDs(features) ids, err := FindNamespacedFeatureIDs(tx, features)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, ids, len(expectedIDs)+1) require.Len(t, ids, len(expectedIDs)+1)
for i, id := range ids { for i, id := range ids {

View File

@ -12,24 +12,28 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package feature
import "github.com/coreos/clair/database" import (
"database/sql"
"github.com/coreos/clair/database"
)
const ( const (
selectAllFeatureTypes = `SELECT id, name FROM feature_type` selectAllFeatureTypes = `SELECT id, name FROM feature_type`
) )
type featureTypes struct { type FeatureTypes struct {
byID map[int]database.FeatureType ByID map[int]database.FeatureType
byName map[database.FeatureType]int ByName map[database.FeatureType]int
} }
func newFeatureTypes() *featureTypes { func newFeatureTypes() *FeatureTypes {
return &featureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)} return &FeatureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)}
} }
func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) { func GetFeatureTypeMap(tx *sql.Tx) (*FeatureTypes, error) {
rows, err := tx.Query(selectAllFeatureTypes) rows, err := tx.Query(selectAllFeatureTypes)
if err != nil { if err != nil {
return nil, err return nil, err
@ -45,8 +49,8 @@ func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) {
return nil, err return nil, err
} }
types.byID[id] = name types.ByID[id] = name
types.byName[name] = id types.ByName[name] = id
} }
return types, nil return types, nil

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package feature
import ( import (
"testing" "testing"
@ -20,19 +20,20 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
func TestGetFeatureTypeMap(t *testing.T) { func TestGetFeatureTypeMap(t *testing.T) {
tx, cleanup := createTestPgSession(t, "TestGetFeatureTypeMap") tx, cleanup := testutil.CreateTestTx(t, "TestGetFeatureTypeMap")
defer cleanup() defer cleanup()
types, err := tx.getFeatureTypeMap() types, err := GetFeatureTypeMap(tx)
if err != nil { if err != nil {
require.Nil(t, err, err.Error()) require.Nil(t, err, err.Error())
} }
require.Equal(t, database.SourcePackage, types.byID[1]) require.Equal(t, database.SourcePackage, types.ByID[1])
require.Equal(t, database.BinaryPackage, types.byID[2]) require.Equal(t, database.BinaryPackage, types.ByID[2])
require.Equal(t, 1, types.byName[database.SourcePackage]) require.Equal(t, 1, types.ByName[database.SourcePackage])
require.Equal(t, 2, types.byName[database.BinaryPackage]) require.Equal(t, 2, types.ByName[database.BinaryPackage])
} }

View File

@ -0,0 +1,168 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package feature
import (
"database/sql"
"fmt"
"sort"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database/pgsql/util"
)
var soiNamespacedFeature = `
WITH new_feature_ns AS (
INSERT INTO namespaced_feature(feature_id, namespace_id)
SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
RETURNING id
)
SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
UNION
SELECT id FROM new_feature_ns`
func queryPersistNamespacedFeature(count int) string {
return util.QueryPersist(count, "namespaced_feature",
"namespaced_feature_namespace_id_feature_id_key",
"feature_id",
"namespace_id")
}
func querySearchNamespacedFeature(nsfCount int) string {
return fmt.Sprintf(`
SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
WHERE nf.feature_id = f.id
AND nf.namespace_id = n.id
AND n.version_format = f.version_format
AND f.type = t.id
AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
util.QueryString(5, nsfCount),
)
}
type namespacedFeatureWithID struct {
database.NamespacedFeature
ID int64
}
func PersistNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
nsIDs := map[database.Namespace]sql.NullInt64{}
fIDs := map[database.Feature]sql.NullInt64{}
for _, f := range features {
nsIDs[f.Namespace] = sql.NullInt64{}
fIDs[f.Feature] = sql.NullInt64{}
}
fToFind := []database.Feature{}
for f := range fIDs {
fToFind = append(fToFind, f)
}
sort.Slice(fToFind, func(i, j int) bool {
return fToFind[i].Name < fToFind[j].Name ||
fToFind[i].Version < fToFind[j].Version ||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
})
if ids, err := FindFeatureIDs(tx, fToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return database.ErrMissingEntities
}
fIDs[fToFind[i]] = id
}
} else {
return err
}
nsToFind := []database.Namespace{}
for ns := range nsIDs {
nsToFind = append(nsToFind, ns)
}
if ids, err := namespace.FindNamespaceIDs(tx, nsToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return database.ErrMissingEntities
}
nsIDs[nsToFind[i]] = id
}
} else {
return err
}
keys := make([]interface{}, 0, len(features)*2)
for _, f := range features {
keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace])
}
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
if err != nil {
return err
}
return nil
}
func FindNamespacedFeatureIDs(tx *sql.Tx, nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
if len(nfs) == 0 {
return nil, nil
}
nfsMap := map[database.NamespacedFeature]int64{}
keys := make([]interface{}, 0, len(nfs)*5)
for _, nf := range nfs {
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
}
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
if err != nil {
return nil, util.HandleError("searchNamespacedFeature", err)
}
defer rows.Close()
var (
id int64
nf database.NamespacedFeature
)
for rows.Next() {
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
nf.Namespace.VersionFormat = nf.VersionFormat
if err != nil {
return nil, util.HandleError("searchNamespacedFeature", err)
}
nfsMap[nf] = id
}
ids := make([]sql.NullInt64, len(nfs))
for i, nf := range nfs {
if id, ok := nfsMap[nf]; ok {
ids[i] = sql.NullInt64{id, true}
} else {
ids[i] = sql.NullInt64{}
}
}
return ids, nil
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package keyvalue
import ( import (
"database/sql" "database/sql"
@ -20,6 +20,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database/pgsql/monitoring"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
@ -32,24 +34,24 @@ const (
DO UPDATE SET key=$1, value=$2` DO UPDATE SET key=$1, value=$2`
) )
func (tx *pgSession) UpdateKeyValue(key, value string) (err error) { func UpdateKeyValue(tx *sql.Tx, key, value string) (err error) {
if key == "" || value == "" { if key == "" || value == "" {
log.Warning("could not insert a flag which has an empty name or value") log.Warning("could not insert a flag which has an empty name or value")
return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value")
} }
defer observeQueryTime("PersistKeyValue", "all", time.Now()) defer monitoring.ObserveQueryTime("PersistKeyValue", "all", time.Now())
_, err = tx.Exec(upsertKeyValue, key, value) _, err = tx.Exec(upsertKeyValue, key, value)
if err != nil { if err != nil {
return handleError("insertKeyValue", err) return util.HandleError("insertKeyValue", err)
} }
return nil return nil
} }
func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { func FindKeyValue(tx *sql.Tx, key string) (string, bool, error) {
defer observeQueryTime("FindKeyValue", "all", time.Now()) defer monitoring.ObserveQueryTime("FindKeyValue", "all", time.Now())
var value string var value string
err := tx.QueryRow(searchKeyValue, key).Scan(&value) err := tx.QueryRow(searchKeyValue, key).Scan(&value)
@ -59,7 +61,7 @@ func (tx *pgSession) FindKeyValue(key string) (string, bool, error) {
} }
if err != nil { if err != nil {
return "", false, handleError("searchKeyValue", err) return "", false, util.HandleError("searchKeyValue", err)
} }
return value, true, nil return value, true, nil

View File

@ -12,38 +12,39 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package keyvalue
import ( import (
"testing" "testing"
"github.com/coreos/clair/database/pgsql/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestKeyValue(t *testing.T) { func TestKeyValue(t *testing.T) {
datastore, tx := openSessionForTest(t, "KeyValue", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "KeyValue")
defer closeTest(t, datastore, tx) defer cleanup()
// Get non-existing key/value // Get non-existing key/value
f, ok, err := tx.FindKeyValue("test") f, ok, err := FindKeyValue(tx, "test")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
// Try to insert invalid key/value. // Try to insert invalid key/value.
assert.Error(t, tx.UpdateKeyValue("test", "")) assert.Error(t, UpdateKeyValue(tx, "test", ""))
assert.Error(t, tx.UpdateKeyValue("", "test")) assert.Error(t, UpdateKeyValue(tx, "", "test"))
assert.Error(t, tx.UpdateKeyValue("", "")) assert.Error(t, UpdateKeyValue(tx, "", ""))
// Insert and verify. // Insert and verify.
assert.Nil(t, tx.UpdateKeyValue("test", "test1")) assert.Nil(t, UpdateKeyValue(tx, "test", "test1"))
f, ok, err = tx.FindKeyValue("test") f, ok, err = FindKeyValue(tx, "test")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "test1", f) assert.Equal(t, "test1", f)
// Update and verify. // Update and verify.
assert.Nil(t, tx.UpdateKeyValue("test", "test2")) assert.Nil(t, UpdateKeyValue(tx, "test", "test2"))
f, ok, err = tx.FindKeyValue("test") f, ok, err = FindKeyValue(tx, "test")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "test2", f) assert.Equal(t, "test2", f)

View File

@ -1,379 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"sort"
"github.com/deckarep/golang-set"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr"
)
const (
soiLayer = `
WITH new_layer AS (
INSERT INTO layer (hash)
SELECT CAST ($1 AS VARCHAR)
WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1)
RETURNING id
)
SELECT id FROM new_Layer
UNION
SELECT id FROM layer WHERE hash = $1`
findLayerFeatures = `
SELECT
f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format
FROM
layer_feature AS lf
LEFT JOIN feature f on f.id = lf.feature_id
LEFT JOIN feature_type ft on ft.id = f.type
LEFT JOIN namespace ns ON ns.id = lf.namespace_id
WHERE lf.layer_id = $1`
findLayerNamespaces = `
SELECT ns.name, ns.version_format, ln.detector_id
FROM layer_namespace AS ln, namespace AS ns
WHERE ln.namespace_id = ns.id
AND ln.layer_id = $1`
findLayerID = `SELECT id FROM layer WHERE hash = $1`
)
// dbLayerNamespace represents the layer_namespace table.
type dbLayerNamespace struct {
layerID int64
namespaceID int64
detectorID int64
}
// dbLayerFeature represents the layer_feature table
type dbLayerFeature struct {
layerID int64
featureID int64
detectorID int64
namespaceID sql.NullInt64
}
func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) {
layer := database.Layer{Hash: hash}
if hash == "" {
return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.")
}
layerID, ok, err := tx.findLayerID(hash)
if err != nil || !ok {
return layer, ok, err
}
detectorMap, err := tx.findAllDetectors()
if err != nil {
return layer, false, err
}
if layer.By, err = tx.findLayerDetectors(layerID); err != nil {
return layer, false, err
}
if layer.Features, err = tx.findLayerFeatures(layerID, detectorMap); err != nil {
return layer, false, err
}
if layer.Namespaces, err = tx.findLayerNamespaces(layerID, detectorMap); err != nil {
return layer, false, err
}
return layer, true, nil
}
func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
if hash == "" {
return commonerr.NewBadRequestError("expected non-empty layer hash")
}
detectedBySet := mapset.NewSet()
for _, d := range detectedBy {
detectedBySet.Add(d)
}
for _, f := range features {
if !detectedBySet.Contains(f.By) {
return database.ErrInvalidParameters
}
}
for _, n := range namespaces {
if !detectedBySet.Contains(n.By) {
return database.ErrInvalidParameters
}
}
return nil
}
// PersistLayer saves the content of a layer to the database.
func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
var (
err error
id int64
detectorIDs []int64
)
if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil {
return err
}
if id, err = tx.soiLayer(hash); err != nil {
return err
}
if detectorIDs, err = tx.findDetectorIDs(detectedBy); err != nil {
if err == commonerr.ErrNotFound {
return database.ErrMissingEntities
}
return err
}
if err = tx.persistLayerDetectors(id, detectorIDs); err != nil {
return err
}
if err = tx.persistAllLayerFeatures(id, features); err != nil {
return err
}
if err = tx.persistAllLayerNamespaces(id, namespaces); err != nil {
return err
}
return nil
}
func (tx *pgSession) persistAllLayerNamespaces(layerID int64, namespaces []database.LayerNamespace) error {
detectorMap, err := tx.findAllDetectors()
if err != nil {
return err
}
// TODO(sidac): This kind of type conversion is very useless and wasteful,
// we need interfaces around the database models to reduce these kind of
// operations.
rawNamespaces := make([]database.Namespace, 0, len(namespaces))
for _, ns := range namespaces {
rawNamespaces = append(rawNamespaces, ns.Namespace)
}
rawNamespaceIDs, err := tx.findNamespaceIDs(rawNamespaces)
if err != nil {
return err
}
dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces))
for i, ns := range namespaces {
detectorID := detectorMap.byValue[ns.By]
namespaceID := rawNamespaceIDs[i].Int64
if !rawNamespaceIDs[i].Valid {
return database.ErrMissingEntities
}
dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID})
}
return tx.persistLayerNamespaces(dbLayerNamespaces)
}
func (tx *pgSession) persistAllLayerFeatures(layerID int64, features []database.LayerFeature) error {
detectorMap, err := tx.findAllDetectors()
if err != nil {
return err
}
var namespaces []database.Namespace
for _, feature := range features {
namespaces = append(namespaces, feature.PotentialNamespace)
}
nameSpaceIDs, _ := tx.findNamespaceIDs(namespaces)
featureNamespaceMap := map[database.Namespace]sql.NullInt64{}
rawFeatures := make([]database.Feature, 0, len(features))
for i, f := range features {
rawFeatures = append(rawFeatures, f.Feature)
if f.PotentialNamespace.Valid() {
featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i]
}
}
featureIDs, err := tx.findFeatureIDs(rawFeatures)
if err != nil {
return err
}
var namespaceID sql.NullInt64
dbFeatures := make([]dbLayerFeature, 0, len(features))
for i, f := range features {
detectorID := detectorMap.byValue[f.By]
if !featureIDs[i].Valid {
return database.ErrMissingEntities
}
featureID := featureIDs[i].Int64
namespaceID = featureNamespaceMap[f.PotentialNamespace]
dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID})
}
if err := tx.persistLayerFeatures(dbFeatures); err != nil {
return err
}
return nil
}
func (tx *pgSession) persistLayerFeatures(features []dbLayerFeature) error {
if len(features) == 0 {
return nil
}
sort.Slice(features, func(i, j int) bool {
return features[i].featureID < features[j].featureID
})
keys := make([]interface{}, 0, len(features)*4)
for _, f := range features {
keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID)
}
_, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...)
if err != nil {
return handleError("queryPersistLayerFeature", err)
}
return nil
}
func (tx *pgSession) persistLayerNamespaces(namespaces []dbLayerNamespace) error {
if len(namespaces) == 0 {
return nil
}
// for every bulk persist operation, the input data should be sorted.
sort.Slice(namespaces, func(i, j int) bool {
return namespaces[i].namespaceID < namespaces[j].namespaceID
})
keys := make([]interface{}, 0, len(namespaces)*3)
for _, row := range namespaces {
keys = append(keys, row.layerID, row.namespaceID, row.detectorID)
}
_, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
if err != nil {
return handleError("queryPersistLayerNamespace", err)
}
return nil
}
func (tx *pgSession) findLayerNamespaces(layerID int64, detectors detectorMap) ([]database.LayerNamespace, error) {
rows, err := tx.Query(findLayerNamespaces, layerID)
if err != nil {
return nil, handleError("findLayerNamespaces", err)
}
namespaces := []database.LayerNamespace{}
for rows.Next() {
var (
namespace database.LayerNamespace
detectorID int64
)
if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil {
return nil, err
}
namespace.By = detectors.byID[detectorID]
namespaces = append(namespaces, namespace)
}
return namespaces, nil
}
func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]database.LayerFeature, error) {
rows, err := tx.Query(findLayerFeatures, layerID)
if err != nil {
return nil, handleError("findLayerFeatures", err)
}
defer rows.Close()
features := []database.LayerFeature{}
for rows.Next() {
var (
detectorID int64
feature database.LayerFeature
)
var namespaceName, namespaceVersion sql.NullString
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil {
return nil, handleError("findLayerFeatures", err)
}
feature.PotentialNamespace.Name = namespaceName.String
feature.PotentialNamespace.VersionFormat = namespaceVersion.String
feature.By = detectors.byID[detectorID]
features = append(features, feature)
}
return features, nil
}
func (tx *pgSession) findLayerID(hash string) (int64, bool, error) {
var layerID int64
err := tx.QueryRow(findLayerID, hash).Scan(&layerID)
if err != nil {
if err == sql.ErrNoRows {
return layerID, false, nil
}
return layerID, false, handleError("findLayerID", err)
}
return layerID, true, nil
}
func (tx *pgSession) findLayerIDs(hashes []string) ([]int64, bool, error) {
layerIDs := make([]int64, 0, len(hashes))
for _, hash := range hashes {
id, ok, err := tx.findLayerID(hash)
if !ok {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
layerIDs = append(layerIDs, id)
}
return layerIDs, true, nil
}
func (tx *pgSession) soiLayer(hash string) (int64, error) {
var id int64
if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
return 0, handleError("soiLayer", err)
}
return id, nil
}

View File

@ -0,0 +1,177 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package layer
import (
"database/sql"
"github.com/deckarep/golang-set"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr"
)
const (
soiLayer = `
WITH new_layer AS (
INSERT INTO layer (hash)
SELECT CAST ($1 AS VARCHAR)
WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1)
RETURNING id
)
SELECT id FROM new_Layer
UNION
SELECT id FROM layer WHERE hash = $1`
findLayerID = `SELECT id FROM layer WHERE hash = $1`
)
func FindLayer(tx *sql.Tx, hash string) (database.Layer, bool, error) {
layer := database.Layer{Hash: hash}
if hash == "" {
return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.")
}
layerID, ok, err := FindLayerID(tx, hash)
if err != nil || !ok {
return layer, ok, err
}
detectorMap, err := detector.FindAllDetectors(tx)
if err != nil {
return layer, false, err
}
if layer.By, err = FindLayerDetectors(tx, layerID); err != nil {
return layer, false, err
}
if layer.Features, err = FindLayerFeatures(tx, layerID, detectorMap); err != nil {
return layer, false, err
}
if layer.Namespaces, err = FindLayerNamespaces(tx, layerID, detectorMap); err != nil {
return layer, false, err
}
return layer, true, nil
}
func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
if hash == "" {
return commonerr.NewBadRequestError("expected non-empty layer hash")
}
detectedBySet := mapset.NewSet()
for _, d := range detectedBy {
detectedBySet.Add(d)
}
for _, f := range features {
if !detectedBySet.Contains(f.By) {
return database.ErrInvalidParameters
}
}
for _, n := range namespaces {
if !detectedBySet.Contains(n.By) {
return database.ErrInvalidParameters
}
}
return nil
}
// PersistLayer saves the content of a layer to the database.
func PersistLayer(tx *sql.Tx, hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
var (
err error
id int64
detectorIDs []int64
)
if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil {
return err
}
if id, err = SoiLayer(tx, hash); err != nil {
return err
}
if detectorIDs, err = detector.FindDetectorIDs(tx, detectedBy); err != nil {
if err == commonerr.ErrNotFound {
return database.ErrMissingEntities
}
return err
}
if err = PersistLayerDetectors(tx, id, detectorIDs); err != nil {
return err
}
if err = PersistAllLayerFeatures(tx, id, features); err != nil {
return err
}
if err = PersistAllLayerNamespaces(tx, id, namespaces); err != nil {
return err
}
return nil
}
func FindLayerID(tx *sql.Tx, hash string) (int64, bool, error) {
var layerID int64
err := tx.QueryRow(findLayerID, hash).Scan(&layerID)
if err != nil {
if err == sql.ErrNoRows {
return layerID, false, nil
}
return layerID, false, util.HandleError("findLayerID", err)
}
return layerID, true, nil
}
func FindLayerIDs(tx *sql.Tx, hashes []string) ([]int64, bool, error) {
layerIDs := make([]int64, 0, len(hashes))
for _, hash := range hashes {
id, ok, err := FindLayerID(tx, hash)
if !ok {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
layerIDs = append(layerIDs, id)
}
return layerIDs, true, nil
}
func SoiLayer(tx *sql.Tx, hash string) (int64, error) {
var id int64
if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
return 0, util.HandleError("soiLayer", err)
}
return id, nil
}

View File

@ -0,0 +1,66 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package layer
import (
"database/sql"
"github.com/deckarep/golang-set"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/util"
)
const (
selectLayerDetectors = `
SELECT d.name, d.version, d.dtype
FROM layer_detector, detector AS d
WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`
persistLayerDetector = `
INSERT INTO layer_detector (layer_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`
)
func PersistLayerDetector(tx *sql.Tx, layerID int64, detectorID int64) error {
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
return util.HandleError("persistLayerDetector", err)
}
return nil
}
func PersistLayerDetectors(tx *sql.Tx, layerID int64, detectorIDs []int64) error {
alreadySaved := mapset.NewSet()
for _, id := range detectorIDs {
if alreadySaved.Contains(id) {
continue
}
alreadySaved.Add(id)
if err := PersistLayerDetector(tx, layerID, id); err != nil {
return err
}
}
return nil
}
func FindLayerDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) {
detectors, err := detector.GetDetectors(tx, selectLayerDetectors, id)
return detectors, err
}

View File

@ -0,0 +1,147 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package layer
import (
"database/sql"
"sort"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/util"
)
const findLayerFeatures = `
SELECT
f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format
FROM
layer_feature AS lf
LEFT JOIN feature f on f.id = lf.feature_id
LEFT JOIN feature_type ft on ft.id = f.type
LEFT JOIN namespace ns ON ns.id = lf.namespace_id
WHERE lf.layer_id = $1`
func queryPersistLayerFeature(count int) string {
return util.QueryPersist(count,
"layer_feature",
"layer_feature_layer_id_feature_id_namespace_id_key",
"layer_id",
"feature_id",
"detector_id",
"namespace_id")
}
// dbLayerFeature represents the layer_feature table
type dbLayerFeature struct {
layerID int64
featureID int64
detectorID int64
namespaceID sql.NullInt64
}
func FindLayerFeatures(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerFeature, error) {
rows, err := tx.Query(findLayerFeatures, layerID)
if err != nil {
return nil, util.HandleError("findLayerFeatures", err)
}
defer rows.Close()
features := []database.LayerFeature{}
for rows.Next() {
var (
detectorID int64
feature database.LayerFeature
)
var namespaceName, namespaceVersion sql.NullString
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil {
return nil, util.HandleError("findLayerFeatures", err)
}
feature.PotentialNamespace.Name = namespaceName.String
feature.PotentialNamespace.VersionFormat = namespaceVersion.String
feature.By = detectors.ByID[detectorID]
features = append(features, feature)
}
return features, nil
}
func PersistAllLayerFeatures(tx *sql.Tx, layerID int64, features []database.LayerFeature) error {
detectorMap, err := detector.FindAllDetectors(tx)
if err != nil {
return err
}
var namespaces []database.Namespace
for _, feature := range features {
namespaces = append(namespaces, feature.PotentialNamespace)
}
nameSpaceIDs, _ := namespace.FindNamespaceIDs(tx, namespaces)
featureNamespaceMap := map[database.Namespace]sql.NullInt64{}
rawFeatures := make([]database.Feature, 0, len(features))
for i, f := range features {
rawFeatures = append(rawFeatures, f.Feature)
if f.PotentialNamespace.Valid() {
featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i]
}
}
featureIDs, err := feature.FindFeatureIDs(tx, rawFeatures)
if err != nil {
return err
}
var namespaceID sql.NullInt64
dbFeatures := make([]dbLayerFeature, 0, len(features))
for i, f := range features {
detectorID := detectorMap.ByValue[f.By]
featureID := featureIDs[i].Int64
if !featureIDs[i].Valid {
return database.ErrMissingEntities
}
namespaceID = featureNamespaceMap[f.PotentialNamespace]
dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID})
}
if err := PersistLayerFeatures(tx, dbFeatures); err != nil {
return err
}
return nil
}
func PersistLayerFeatures(tx *sql.Tx, features []dbLayerFeature) error {
if len(features) == 0 {
return nil
}
sort.Slice(features, func(i, j int) bool {
return features[i].featureID < features[j].featureID
})
keys := make([]interface{}, 0, len(features)*4)
for _, f := range features {
keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID)
}
_, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...)
if err != nil {
return util.HandleError("queryPersistLayerFeature", err)
}
return nil
}

View File

@ -0,0 +1,127 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package layer
import (
"database/sql"
"sort"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database/pgsql/util"
)
const findLayerNamespaces = `
SELECT ns.name, ns.version_format, ln.detector_id
FROM layer_namespace AS ln, namespace AS ns
WHERE ln.namespace_id = ns.id
AND ln.layer_id = $1`
func queryPersistLayerNamespace(count int) string {
return util.QueryPersist(count,
"layer_namespace",
"layer_namespace_layer_id_namespace_id_key",
"layer_id",
"namespace_id",
"detector_id")
}
// dbLayerNamespace represents the layer_namespace table.
type dbLayerNamespace struct {
layerID int64
namespaceID int64
detectorID int64
}
func FindLayerNamespaces(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerNamespace, error) {
rows, err := tx.Query(findLayerNamespaces, layerID)
if err != nil {
return nil, util.HandleError("findLayerNamespaces", err)
}
namespaces := []database.LayerNamespace{}
for rows.Next() {
var (
namespace database.LayerNamespace
detectorID int64
)
if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil {
return nil, err
}
namespace.By = detectors.ByID[detectorID]
namespaces = append(namespaces, namespace)
}
return namespaces, nil
}
func PersistAllLayerNamespaces(tx *sql.Tx, layerID int64, namespaces []database.LayerNamespace) error {
detectorMap, err := detector.FindAllDetectors(tx)
if err != nil {
return err
}
// TODO(sidac): This kind of type conversion is very useless and wasteful,
// we need interfaces around the database models to reduce these kind of
// operations.
rawNamespaces := make([]database.Namespace, 0, len(namespaces))
for _, ns := range namespaces {
rawNamespaces = append(rawNamespaces, ns.Namespace)
}
rawNamespaceIDs, err := namespace.FindNamespaceIDs(tx, rawNamespaces)
if err != nil {
return err
}
dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces))
for i, ns := range namespaces {
detectorID := detectorMap.ByValue[ns.By]
namespaceID := rawNamespaceIDs[i].Int64
if !rawNamespaceIDs[i].Valid {
return database.ErrMissingEntities
}
dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID})
}
return PersistLayerNamespaces(tx, dbLayerNamespaces)
}
func PersistLayerNamespaces(tx *sql.Tx, namespaces []dbLayerNamespace) error {
if len(namespaces) == 0 {
return nil
}
// for every bulk persist operation, the input data should be sorted.
sort.Slice(namespaces, func(i, j int) bool {
return namespaces[i].namespaceID < namespaces[j].namespaceID
})
keys := make([]interface{}, 0, len(namespaces)*3)
for _, row := range namespaces {
keys = append(keys, row.layerID, row.namespaceID, row.detectorID)
}
_, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
if err != nil {
return util.HandleError("queryPersistLayerNamespace", err)
}
return nil
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package layer
import ( import (
"testing" "testing"
@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
var persistLayerTests = []struct { var persistLayerTests = []struct {
@ -39,9 +40,9 @@ var persistLayerTests = []struct {
{ {
title: "layer with inconsistent feature and detectors", title: "layer with inconsistent feature and detectors",
name: "random-forest", name: "random-forest",
by: []database.Detector{realDetectors[2]}, by: []database.Detector{testutil.RealDetectors[2]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{realFeatures[1], realDetectors[1], database.Namespace{}}, {testutil.RealFeatures[1], testutil.RealDetectors[1], database.Namespace{}},
}, },
err: "parameters are not valid", err: "parameters are not valid",
}, },
@ -49,70 +50,71 @@ var persistLayerTests = []struct {
title: "layer with non-existing feature", title: "layer with non-existing feature",
name: "random-forest", name: "random-forest",
err: "associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[2]}, by: []database.Detector{testutil.RealDetectors[2]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{fakeFeatures[1], realDetectors[2], database.Namespace{}}, {testutil.FakeFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
}, },
}, },
{ {
title: "layer with non-existing namespace", title: "layer with non-existing namespace",
name: "random-forest2", name: "random-forest2",
err: "associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[1]}, by: []database.Detector{testutil.RealDetectors[1]},
namespaces: []database.LayerNamespace{ namespaces: []database.LayerNamespace{
{fakeNamespaces[1], realDetectors[1]}, {testutil.FakeNamespaces[1], testutil.RealDetectors[1]},
}, },
}, },
{ {
title: "layer with non-existing detector", title: "layer with non-existing detector",
name: "random-forest3", name: "random-forest3",
err: "associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{fakeDetector[1]}, by: []database.Detector{testutil.FakeDetector[1]},
}, },
{ {
title: "valid layer", title: "valid layer",
name: "hamsterhouse", name: "hamsterhouse",
by: []database.Detector{realDetectors[1], realDetectors[2]}, by: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{realFeatures[1], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
{realFeatures[2], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
}, },
namespaces: []database.LayerNamespace{ namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]}, {testutil.RealNamespaces[1], testutil.RealDetectors[1]},
}, },
layer: &database.Layer{ layer: &database.Layer{
Hash: "hamsterhouse", Hash: "hamsterhouse",
By: []database.Detector{realDetectors[1], realDetectors[2]}, By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
Features: []database.LayerFeature{ Features: []database.LayerFeature{
{realFeatures[1], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
{realFeatures[2], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
}, },
Namespaces: []database.LayerNamespace{ Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]}, {testutil.RealNamespaces[1], testutil.RealDetectors[1]},
}, },
}, },
}, },
{ {
title: "update existing layer", title: "update existing layer",
name: "layer-1", name: "layer-1",
by: []database.Detector{realDetectors[3], realDetectors[4]}, by: []database.Detector{testutil.RealDetectors[3], testutil.RealDetectors[4]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{realFeatures[4], realDetectors[3], database.Namespace{}}, {testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}},
}, },
namespaces: []database.LayerNamespace{ namespaces: []database.LayerNamespace{
{realNamespaces[3], realDetectors[4]}, {testutil.RealNamespaces[3], testutil.RealDetectors[4]},
}, },
layer: &database.Layer{ layer: &database.Layer{
Hash: "layer-1", Hash: "layer-1",
By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2], testutil.RealDetectors[3], testutil.RealDetectors[4]},
Features: []database.LayerFeature{ Features: []database.LayerFeature{
{realFeatures[1], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
{realFeatures[2], realDetectors[2], database.Namespace{}}, {testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
{realFeatures[4], realDetectors[3], database.Namespace{}}, {testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}},
}, },
Namespaces: []database.LayerNamespace{ Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]}, {testutil.RealNamespaces[1], testutil.RealDetectors[1]},
{realNamespaces[3], realDetectors[4]}, {testutil.RealNamespaces[3], testutil.RealDetectors[4]},
}, },
}, },
}, },
@ -120,33 +122,33 @@ var persistLayerTests = []struct {
{ {
title: "layer with potential namespace", title: "layer with potential namespace",
name: "layer-potential-namespace", name: "layer-potential-namespace",
by: []database.Detector{realDetectors[3]}, by: []database.Detector{testutil.RealDetectors[3]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{realFeatures[4], realDetectors[3], realNamespaces[4]}, {testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]},
}, },
namespaces: []database.LayerNamespace{ namespaces: []database.LayerNamespace{
{realNamespaces[3], realDetectors[3]}, {testutil.RealNamespaces[3], testutil.RealDetectors[3]},
}, },
layer: &database.Layer{ layer: &database.Layer{
Hash: "layer-potential-namespace", Hash: "layer-potential-namespace",
By: []database.Detector{realDetectors[3]}, By: []database.Detector{testutil.RealDetectors[3]},
Features: []database.LayerFeature{ Features: []database.LayerFeature{
{realFeatures[4], realDetectors[3], realNamespaces[4]}, {testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]},
}, },
Namespaces: []database.LayerNamespace{ Namespaces: []database.LayerNamespace{
{realNamespaces[3], realDetectors[3]}, {testutil.RealNamespaces[3], testutil.RealDetectors[3]},
}, },
}, },
}, },
} }
func TestPersistLayer(t *testing.T) { func TestPersistLayer(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistLayer", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistLayer")
defer closeTest(t, datastore, tx) defer cleanup()
for _, test := range persistLayerTests { for _, test := range persistLayerTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
err := tx.PersistLayer(test.name, test.features, test.namespaces, test.by) err := PersistLayer(tx, test.name, test.features, test.namespaces, test.by)
if test.err != "" { if test.err != "" {
assert.EqualError(t, err, test.err, "unexpected error") assert.EqualError(t, err, test.err, "unexpected error")
return return
@ -154,7 +156,7 @@ func TestPersistLayer(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
if test.layer != nil { if test.layer != nil {
layer, ok, err := tx.FindLayer(test.name) layer, ok, err := FindLayer(tx, test.name)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
database.AssertLayerEqual(t, test.layer, &layer) database.AssertLayerEqual(t, test.layer, &layer)
@ -186,17 +188,17 @@ var findLayerTests = []struct {
title: "existing layer", title: "existing layer",
in: "layer-4", in: "layer-4",
ok: true, ok: true,
out: takeLayerPointerFromMap(realLayers, 6), out: testutil.TakeLayerPointerFromMap(testutil.RealLayers, 6),
}, },
} }
func TestFindLayer(t *testing.T) { func TestFindLayer(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindLayer", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindLayer")
defer closeTest(t, datastore, tx) defer cleanup()
for _, test := range findLayerTests { for _, test := range findLayerTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
layer, ok, err := tx.FindLayer(test.in) layer, ok, err := FindLayer(tx, test.in)
if test.err != "" { if test.err != "" {
assert.EqualError(t, err, test.err, "unexpected error") assert.EqualError(t, err, test.err, "unexpected error")
return return

View File

@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package lock
import ( import (
"database/sql"
"time" "time"
"github.com/coreos/clair/database/pgsql/monitoring"
"github.com/coreos/clair/database/pgsql/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -38,12 +41,12 @@ const (
SELECT owner, until FROM lock WHERE name = $1` SELECT owner, until FROM lock WHERE name = $1`
) )
func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { func AcquireLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
if lockName == "" || whoami == "" || desiredDuration == 0 { if lockName == "" || whoami == "" || desiredDuration == 0 {
panic("invalid lock parameters") panic("invalid lock parameters")
} }
if err := tx.pruneLocks(); err != nil { if err := PruneLocks(tx); err != nil {
return false, time.Time{}, err return false, time.Time{}, err
} }
@ -54,22 +57,22 @@ func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.D
lockOwner string lockOwner string
) )
defer observeQueryTime("Lock", "soiLock", time.Now()) defer monitoring.ObserveQueryTime("Lock", "soiLock", time.Now())
err := tx.QueryRow(soiLock, lockName, whoami, desiredLockedUntil).Scan(&lockOwner, &lockedUntil) err := tx.QueryRow(soiLock, lockName, whoami, desiredLockedUntil).Scan(&lockOwner, &lockedUntil)
return lockOwner == whoami, lockedUntil, err return lockOwner == whoami, lockedUntil, util.HandleError("AcquireLock", err)
} }
func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) { func ExtendLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
if lockName == "" || whoami == "" || desiredDuration == 0 { if lockName == "" || whoami == "" || desiredDuration == 0 {
panic("invalid lock parameters") panic("invalid lock parameters")
} }
desiredLockedUntil := time.Now().Add(desiredDuration) desiredLockedUntil := time.Now().Add(desiredDuration)
defer observeQueryTime("Lock", "update", time.Now()) defer monitoring.ObserveQueryTime("Lock", "update", time.Now())
result, err := tx.Exec(updateLock, lockName, whoami, desiredLockedUntil) result, err := tx.Exec(updateLock, lockName, whoami, desiredLockedUntil)
if err != nil { if err != nil {
return false, time.Time{}, handleError("updateLock", err) return false, time.Time{}, util.HandleError("updateLock", err)
} }
if numRows, err := result.RowsAffected(); err == nil { if numRows, err := result.RowsAffected(); err == nil {
@ -77,27 +80,27 @@ func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Du
return numRows > 0, desiredLockedUntil, nil return numRows > 0, desiredLockedUntil, nil
} }
return false, time.Time{}, handleError("updateLock", err) return false, time.Time{}, util.HandleError("updateLock", err)
} }
func (tx *pgSession) ReleaseLock(name, owner string) error { func ReleaseLock(tx *sql.Tx, name, owner string) error {
if name == "" || owner == "" { if name == "" || owner == "" {
panic("invalid lock parameters") panic("invalid lock parameters")
} }
defer observeQueryTime("Unlock", "all", time.Now()) defer monitoring.ObserveQueryTime("Unlock", "all", time.Now())
_, err := tx.Exec(removeLock, name, owner) _, err := tx.Exec(removeLock, name, owner)
return err return err
} }
// pruneLocks removes every expired locks from the database // pruneLocks removes every expired locks from the database
func (tx *pgSession) pruneLocks() error { func PruneLocks(tx *sql.Tx) error {
defer observeQueryTime("pruneLocks", "all", time.Now()) defer monitoring.ObserveQueryTime("pruneLocks", "all", time.Now())
if r, err := tx.Exec(removeLockExpired, time.Now().UTC()); err != nil { if r, err := tx.Exec(removeLockExpired, time.Now().UTC()); err != nil {
return handleError("removeLockExpired", err) return util.HandleError("removeLockExpired", err)
} else if affected, err := r.RowsAffected(); err != nil { } else if affected, err := r.RowsAffected(); err != nil {
return handleError("removeLockExpired", err) return util.HandleError("removeLockExpired", err)
} else { } else {
log.Debugf("Pruned %d Locks", affected) log.Debugf("Pruned %d Locks", affected)
} }

View File

@ -0,0 +1,100 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package lock
import (
"testing"
"time"
"github.com/coreos/clair/database/pgsql/testutil"
"github.com/stretchr/testify/require"
)
func TestAcquireLockReturnsExistingLockDuration(t *testing.T) {
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "Lock")
defer cleanup()
acquired, originalExpiration, err := AcquireLock(tx, "test1", "owner1", time.Minute)
require.Nil(t, err)
require.True(t, acquired)
acquired2, expiration, err := AcquireLock(tx, "test1", "owner2", time.Hour)
require.Nil(t, err)
require.False(t, acquired2)
require.Equal(t, expiration, originalExpiration)
}
func TestLock(t *testing.T) {
db, cleanup := testutil.CreateTestDBWithFixture(t, "Lock")
defer cleanup()
tx, err := db.Begin()
if err != nil {
panic(err)
}
// Create a first lock.
l, _, err := AcquireLock(tx, "test1", "owner1", time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, true)
// lock again by itself, the previous lock is not expired yet.
l, _, err = AcquireLock(tx, "test1", "owner1", time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, false)
// Try to renew the same lock with another owner.
l, _, err = ExtendLock(tx, "test1", "owner2", time.Minute)
require.Nil(t, err)
require.False(t, l)
tx = testutil.RestartTransaction(db, tx, false)
l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute)
require.Nil(t, err)
require.False(t, l)
tx = testutil.RestartTransaction(db, tx, false)
// Renew the lock.
l, _, err = ExtendLock(tx, "test1", "owner1", 2*time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, true)
// Unlock and then relock by someone else.
err = ReleaseLock(tx, "test1", "owner1")
require.Nil(t, err)
tx = testutil.RestartTransaction(db, tx, true)
l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, true)
// Create a second lock which is actually already expired ...
l, _, err = AcquireLock(tx, "test2", "owner1", -time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, true)
// Take over the lock
l, _, err = AcquireLock(tx, "test2", "owner2", time.Minute)
require.Nil(t, err)
require.True(t, l)
tx = testutil.RestartTransaction(db, tx, true)
require.Nil(t, tx.Rollback())
}

View File

@ -1,99 +0,0 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAcquireLockReturnsExistingLockDuration(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "Lock")
defer cleanup()
acquired, originalExpiration, err := tx.AcquireLock("test1", "owner1", time.Minute)
require.Nil(t, err)
require.True(t, acquired)
acquired2, expiration, err := tx.AcquireLock("test1", "owner2", time.Hour)
require.Nil(t, err)
require.False(t, acquired2)
require.Equal(t, expiration, originalExpiration)
}
func TestLock(t *testing.T) {
datastore, tx := openSessionForTest(t, "Lock", true)
defer datastore.Close()
var l bool
// Create a first lock.
l, _, err := tx.AcquireLock("test1", "owner1", time.Minute)
assert.Nil(t, err)
assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// lock again by itself, the previous lock is not expired yet.
l, _, err = tx.AcquireLock("test1", "owner1", time.Minute)
assert.Nil(t, err)
assert.True(t, l) // acquire lock no-op when owner already has the lock.
tx = restartSession(t, datastore, tx, false)
// Try to renew the same lock with another owner.
l, _, err = tx.ExtendLock("test1", "owner2", time.Minute)
assert.Nil(t, err)
assert.False(t, l)
tx = restartSession(t, datastore, tx, false)
l, _, err = tx.AcquireLock("test1", "owner2", time.Minute)
assert.Nil(t, err)
assert.False(t, l)
tx = restartSession(t, datastore, tx, false)
// Renew the lock.
l, _, err = tx.ExtendLock("test1", "owner1", 2*time.Minute)
assert.Nil(t, err)
assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Unlock and then relock by someone else.
err = tx.ReleaseLock("test1", "owner1")
assert.Nil(t, err)
tx = restartSession(t, datastore, tx, true)
l, _, err = tx.AcquireLock("test1", "owner2", time.Minute)
assert.Nil(t, err)
assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Create a second lock which is actually already expired ...
l, _, err = tx.AcquireLock("test2", "owner1", -time.Minute)
assert.Nil(t, err)
assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Take over the lock
l, _, err = tx.AcquireLock("test2", "owner2", time.Minute)
assert.Nil(t, err)
assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
if !assert.Nil(t, tx.Rollback()) {
t.FailNow()
}
}

View File

@ -12,22 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package migrations_test
import ( import (
"testing" "testing"
"github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/database/pgsql/testutil"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/remind101/migrate" "github.com/remind101/migrate"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coreos/clair/database/pgsql/migrations"
) )
var userTableCount = `SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'` var userTableCount = `SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'`
func TestMigration(t *testing.T) { func TestMigration(t *testing.T) {
db, cleanup := createAndConnectTestDB(t, "TestMigration") db, cleanup := testutil.CreateAndConnectTestDB(t, "TestMigration")
defer cleanup() defer cleanup()
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)

View File

@ -0,0 +1,67 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package monitoring
import (
"time"
"github.com/prometheus/client_golang/prometheus"
)
var (
PromErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_errors_total",
Help: "Number of errors that PostgreSQL requests generated.",
}, []string{"request"})
PromCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_hits_total",
Help: "Number of cache hits that the PostgreSQL backend did.",
}, []string{"object"})
PromCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_queries_total",
Help: "Number of cache queries that the PostgreSQL backend did.",
}, []string{"object"})
PromQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "clair_pgsql_query_duration_milliseconds",
Help: "Time it takes to execute the database query.",
}, []string{"query", "subquery"})
PromConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "clair_pgsql_concurrent_lock_vafv_total",
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
})
)
func init() {
prometheus.MustRegister(PromErrorsTotal)
prometheus.MustRegister(PromCacheHitsTotal)
prometheus.MustRegister(PromCacheQueriesTotal)
prometheus.MustRegister(PromQueryDurationMilliseconds)
prometheus.MustRegister(PromConcurrentLockVAFV)
}
// monitoring.ObserveQueryTime computes the time elapsed since `start` to represent the
// query time.
// 1. `query` is a pgSession function name.
// 2. `subquery` is a specific query or a batched query.
// 3. `start` is the time right before query is executed.
func ObserveQueryTime(query, subquery string, start time.Time) {
PromQueryDurationMilliseconds.
WithLabelValues(query, subquery).
Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond))
}

View File

@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package namespace
import ( import (
"database/sql" "database/sql"
"fmt"
"sort" "sort"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
@ -26,8 +28,24 @@ const (
searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
) )
func queryPersistNamespace(count int) string {
return util.QueryPersist(count,
"namespace",
"namespace_name_version_format_key",
"name",
"version_format")
}
func querySearchNamespace(nsCount int) string {
return fmt.Sprintf(
`SELECT id, name, version_format
FROM namespace WHERE (name, version_format) IN (%s)`,
util.QueryString(2, nsCount),
)
}
// PersistNamespaces soi namespaces into database. // PersistNamespaces soi namespaces into database.
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { func PersistNamespaces(tx *sql.Tx, namespaces []database.Namespace) error {
if len(namespaces) == 0 { if len(namespaces) == 0 {
return nil return nil
} }
@ -49,12 +67,12 @@ func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
_, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...) _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...)
if err != nil { if err != nil {
return handleError("queryPersistNamespace", err) return util.HandleError("queryPersistNamespace", err)
} }
return nil return nil
} }
func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) { func FindNamespaceIDs(tx *sql.Tx, namespaces []database.Namespace) ([]sql.NullInt64, error) {
if len(namespaces) == 0 { if len(namespaces) == 0 {
return nil, nil return nil, nil
} }
@ -69,7 +87,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu
rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...) rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...)
if err != nil { if err != nil {
return nil, handleError("searchNamespace", err) return nil, util.HandleError("searchNamespace", err)
} }
defer rows.Close() defer rows.Close()
@ -81,7 +99,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu
for rows.Next() { for rows.Next() {
err := rows.Scan(&id, &ns.Name, &ns.VersionFormat) err := rows.Scan(&id, &ns.Name, &ns.VersionFormat)
if err != nil { if err != nil {
return nil, handleError("searchNamespace", err) return nil, util.HandleError("searchNamespace", err)
} }
nsMap[ns] = id nsMap[ns] = id
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package namespace
import ( import (
"testing" "testing"
@ -20,25 +20,26 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/testutil"
) )
func TestPersistNamespaces(t *testing.T) { func TestPersistNamespaces(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistNamespaces", false) tx, cleanup := testutil.CreateTestTx(t, "PersistNamespaces")
defer closeTest(t, datastore, tx) defer cleanup()
ns1 := database.Namespace{} ns1 := database.Namespace{}
ns2 := database.Namespace{Name: "t", VersionFormat: "b"} ns2 := database.Namespace{Name: "t", VersionFormat: "b"}
// Empty Case // Empty Case
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{})) assert.Nil(t, PersistNamespaces(tx, []database.Namespace{}))
// Invalid Case // Invalid Case
assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1})) assert.NotNil(t, PersistNamespaces(tx, []database.Namespace{ns1}))
// Duplicated Case // Duplicated Case
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2})) assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2, ns2}))
// Existing Case // Existing Case
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2})) assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2}))
nsList := listNamespaces(t, tx) nsList := testutil.ListNamespaces(t, tx)
assert.Len(t, nsList, 1) assert.Len(t, nsList, 1)
assert.Equal(t, ns2, nsList[0]) assert.Equal(t, ns2, nsList[0])
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package notification
import ( import (
"testing" "testing"
@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/page"
"github.com/coreos/clair/database/pgsql/testutil"
"github.com/coreos/clair/pkg/pagination" "github.com/coreos/clair/pkg/pagination"
) )
@ -38,6 +40,8 @@ type findVulnerabilityNotificationOut struct {
err string err string
} }
var testPaginationKey = pagination.Must(pagination.NewKey())
var findVulnerabilityNotificationTests = []struct { var findVulnerabilityNotificationTests = []struct {
title string title string
in findVulnerabilityNotificationIn in findVulnerabilityNotificationIn
@ -77,21 +81,21 @@ var findVulnerabilityNotificationTests = []struct {
}, },
out: findVulnerabilityNotificationOut{ out: findVulnerabilityNotificationOut{
&database.VulnerabilityNotificationWithVulnerable{ &database.VulnerabilityNotificationWithVulnerable{
NotificationHook: realNotification[1].NotificationHook, NotificationHook: testutil.RealNotification[1].NotificationHook,
Old: &database.PagedVulnerableAncestries{ Old: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[2], Vulnerability: testutil.RealVulnerability[2],
Limit: 1, Limit: 1,
Affected: make(map[int]string), Affected: make(map[int]string),
Current: mustMarshalToken(testPaginationKey, Page{0}), Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{0}), Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true, End: true,
}, },
New: &database.PagedVulnerableAncestries{ New: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[1], Vulnerability: testutil.RealVulnerability[1],
Limit: 1, Limit: 1,
Affected: map[int]string{3: "ancestry-3"}, Affected: map[int]string{3: "ancestry-3"},
Current: mustMarshalToken(testPaginationKey, Page{0}), Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{4}), Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
End: false, End: false,
}, },
}, },
@ -100,32 +104,31 @@ var findVulnerabilityNotificationTests = []struct {
"", "",
}, },
}, },
{ {
title: "find existing notification of second page of new affected ancestry", title: "find existing notification of second page of new affected ancestry",
in: findVulnerabilityNotificationIn{ in: findVulnerabilityNotificationIn{
notificationName: "test", notificationName: "test",
pageSize: 1, pageSize: 1,
oldAffectedAncestryPage: pagination.FirstPageToken, oldAffectedAncestryPage: pagination.FirstPageToken,
newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}), newAffectedAncestryPage: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
}, },
out: findVulnerabilityNotificationOut{ out: findVulnerabilityNotificationOut{
&database.VulnerabilityNotificationWithVulnerable{ &database.VulnerabilityNotificationWithVulnerable{
NotificationHook: realNotification[1].NotificationHook, NotificationHook: testutil.RealNotification[1].NotificationHook,
Old: &database.PagedVulnerableAncestries{ Old: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[2], Vulnerability: testutil.RealVulnerability[2],
Limit: 1, Limit: 1,
Affected: make(map[int]string), Affected: make(map[int]string),
Current: mustMarshalToken(testPaginationKey, Page{0}), Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{0}), Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true, End: true,
}, },
New: &database.PagedVulnerableAncestries{ New: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[1], Vulnerability: testutil.RealVulnerability[1],
Limit: 1, Limit: 1,
Affected: map[int]string{4: "ancestry-4"}, Affected: map[int]string{4: "ancestry-4"},
Current: mustMarshalToken(testPaginationKey, Page{4}), Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
Next: mustMarshalToken(testPaginationKey, Page{0}), Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true, End: true,
}, },
}, },
@ -137,12 +140,12 @@ var findVulnerabilityNotificationTests = []struct {
} }
func TestFindVulnerabilityNotification(t *testing.T) { func TestFindVulnerabilityNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "pagination", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "pagination")
defer closeTest(t, datastore, tx) defer cleanup()
for _, test := range findVulnerabilityNotificationTests { for _, test := range findVulnerabilityNotificationTests {
t.Run(test.title, func(t *testing.T) { t.Run(test.title, func(t *testing.T) {
notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage) notification, ok, err := FindVulnerabilityNotification(tx, test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage, testutil.TestPaginationKey)
if test.out.err != "" { if test.out.err != "" {
require.EqualError(t, err, test.out.err) require.EqualError(t, err, test.out.err)
return return
@ -155,13 +158,14 @@ func TestFindVulnerabilityNotification(t *testing.T) {
} }
require.True(t, ok) require.True(t, ok)
assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, &notification) testutil.AssertVulnerabilityNotificationWithVulnerableEqual(t, testutil.TestPaginationKey, test.out.notification, &notification)
}) })
} }
} }
func TestInsertVulnerabilityNotifications(t *testing.T) { func TestInsertVulnerabilityNotifications(t *testing.T) {
datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true) datastore, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilityNotifications")
defer cleanup()
n1 := database.VulnerabilityNotification{} n1 := database.VulnerabilityNotification{}
n3 := database.VulnerabilityNotification{ n3 := database.VulnerabilityNotification{
@ -187,34 +191,37 @@ func TestInsertVulnerabilityNotifications(t *testing.T) {
}, },
} }
tx, err := datastore.Begin()
require.Nil(t, err)
// invalid case // invalid case
err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n1})
assert.NotNil(t, err) require.NotNil(t, err)
// invalid case: unknown vulnerability // invalid case: unknown vulnerability
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n3})
assert.NotNil(t, err) require.NotNil(t, err)
// invalid case: duplicated input notification // invalid case: duplicated input notification
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4, n4})
assert.NotNil(t, err) require.NotNil(t, err)
tx = restartSession(t, datastore, tx, false) tx = testutil.RestartTransaction(datastore, tx, false)
// valid case // valid case
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
assert.Nil(t, err) require.Nil(t, err)
// invalid case: notification is already in database // invalid case: notification is already in database
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
assert.NotNil(t, err) require.NotNil(t, err)
closeTest(t, datastore, tx) require.Nil(t, tx.Rollback())
} }
func TestFindNewNotification(t *testing.T) { func TestFindNewNotification(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification") tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNewNotification")
defer cleanup() defer cleanup()
noti, ok, err := tx.FindNewNotification(time.Now()) noti, ok, err := FindNewNotification(tx, time.Now())
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
assert.Equal(t, time.Time{}, noti.Notified) assert.Equal(t, time.Time{}, noti.Notified)
@ -223,13 +230,13 @@ func TestFindNewNotification(t *testing.T) {
} }
// can't find the notified // can't find the notified
assert.Nil(t, tx.MarkNotificationAsRead("test")) assert.Nil(t, MarkNotificationAsRead(tx, "test"))
// if the notified time is before // if the notified time is before
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(10*time.Second)))
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
// can find the notified after a period of time // can find the notified after a period of time
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second))) noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(10*time.Second)))
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
assert.NotEqual(t, time.Time{}, noti.Notified) assert.NotEqual(t, time.Time{}, noti.Notified)
@ -237,37 +244,37 @@ func TestFindNewNotification(t *testing.T) {
assert.Equal(t, time.Time{}, noti.Deleted) assert.Equal(t, time.Time{}, noti.Deleted)
} }
assert.Nil(t, tx.DeleteNotification("test")) assert.Nil(t, DeleteNotification(tx, "test"))
// can't find in any time // can't find in any time
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(1000)))
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(1000)))
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
} }
func TestMarkNotificationAsRead(t *testing.T) { func TestMarkNotificationAsRead(t *testing.T) {
datastore, tx := openSessionForTest(t, "MarkNotificationAsRead", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "MarkNotificationAsRead")
defer closeTest(t, datastore, tx) defer cleanup()
// invalid case: notification doesn't exist // invalid case: notification doesn't exist
assert.NotNil(t, tx.MarkNotificationAsRead("non-existing")) assert.NotNil(t, MarkNotificationAsRead(tx, "non-existing"))
// valid case // valid case
assert.Nil(t, tx.MarkNotificationAsRead("test")) assert.Nil(t, MarkNotificationAsRead(tx, "test"))
// valid case // valid case
assert.Nil(t, tx.MarkNotificationAsRead("test")) assert.Nil(t, MarkNotificationAsRead(tx, "test"))
} }
func TestDeleteNotification(t *testing.T) { func TestDeleteNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "DeleteNotification", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteNotification")
defer closeTest(t, datastore, tx) defer cleanup()
// invalid case: notification doesn't exist // invalid case: notification doesn't exist
assert.NotNil(t, tx.DeleteNotification("non-existing")) assert.NotNil(t, DeleteNotification(tx, "non-existing"))
// valid case // valid case
assert.Nil(t, tx.DeleteNotification("test")) assert.Nil(t, DeleteNotification(tx, "test"))
// invalid case: notification is already deleted // invalid case: notification is already deleted
assert.NotNil(t, tx.DeleteNotification("test")) assert.NotNil(t, DeleteNotification(tx, "test"))
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package notification
import ( import (
"database/sql" "database/sql"
@ -22,6 +22,8 @@ import (
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/database/pgsql/vulnerability"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination" "github.com/coreos/clair/pkg/pagination"
) )
@ -54,26 +56,24 @@ const (
SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
FROM Vulnerability_Notification FROM Vulnerability_Notification
WHERE name = $1` WHERE name = $1`
searchNotificationVulnerableAncestry = `
SELECT DISTINCT ON (a.id)
a.id, a.name
FROM vulnerability_affected_namespaced_feature AS vanf,
ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
WHERE vanf.vulnerability_id = $1
AND a.id >= $2
AND al.ancestry_id = a.id
AND al.id = af.ancestry_layer_id
AND af.namespaced_feature_id = vanf.namespaced_feature_id
ORDER BY a.id ASC
LIMIT $3;`
) )
func queryInsertNotifications(count int) string {
return util.QueryInsert(count,
"vulnerability_notification",
"name",
"created_at",
"old_vulnerability_id",
"new_vulnerability_id",
)
}
var ( var (
errNotificationNotFound = errors.New("requested notification is not found") errNotificationNotFound = errors.New("requested notification is not found")
errVulnerabilityNotFound = errors.New("vulnerability is not in database")
) )
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error {
if len(notifications) == 0 { if len(notifications) == 0 {
return nil return nil
} }
@ -122,26 +122,26 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
oldVulnIDs = append(oldVulnIDs, vulnID) oldVulnIDs = append(oldVulnIDs, vulnID)
} }
ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs) ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs)
if err != nil { if err != nil {
return err return err
} }
for i, id := range ids { for i, id := range ids {
if !id.Valid { if !id.Valid {
return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
} }
newVulnIDMap[newVulnIDs[i]] = id newVulnIDMap[newVulnIDs[i]] = id
} }
ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs) ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs)
if err != nil { if err != nil {
return err return err
} }
for i, id := range ids { for i, id := range ids {
if !id.Valid { if !id.Valid {
return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
} }
oldVulnIDMap[oldVulnIDs[i]] = id oldVulnIDMap[oldVulnIDs[i]] = id
} }
@ -178,13 +178,13 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
// multiple updaters, deadlock may happen. // multiple updaters, deadlock may happen.
_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
if err != nil { if err != nil {
return handleError("queryInsertNotifications", err) return util.HandleError("queryInsertNotifications", err)
} }
return nil return nil
} }
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) { func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) {
var ( var (
notification database.NotificationHook notification database.NotificationHook
created zero.Time created zero.Time
@ -197,7 +197,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return notification, false, nil return notification, false, nil
} }
return notification, false, handleError("searchNotificationAvailable", err) return notification, false, util.HandleError("searchNotificationAvailable", err)
} }
notification.Created = created.Time notification.Created = created.Time
@ -207,71 +207,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
return notification, true, nil return notification, true, nil
} }
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) (
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
currentPage := Page{0}
if currentToken != pagination.FirstPageToken {
if err := tx.key.UnmarshalToken(currentToken, &currentPage); err != nil {
return vulnPage, err
}
}
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
&vulnPage.Name,
&vulnPage.Description,
&vulnPage.Link,
&vulnPage.Severity,
&vulnPage.Metadata,
&vulnPage.Namespace.Name,
&vulnPage.Namespace.VersionFormat,
); err != nil {
return vulnPage, handleError("searchVulnerabilityByID", err)
}
// the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
}
defer rows.Close()
ancestries := []affectedAncestry{}
for rows.Next() {
var ancestry affectedAncestry
err := rows.Scan(&ancestry.id, &ancestry.name)
if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
}
ancestries = append(ancestries, ancestry)
}
lastIndex := 0
if len(ancestries)-1 < limit {
lastIndex = len(ancestries)
vulnPage.End = true
} else {
// Use the last ancestry's ID as the next page.
lastIndex = len(ancestries) - 1
vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
if err != nil {
return vulnPage, err
}
}
vulnPage.Affected = map[int]string{}
for _, ancestry := range ancestries[0:lastIndex] {
vulnPage.Affected[int(ancestry.id)] = ancestry.name
}
vulnPage.Current, err = tx.key.MarshalToken(currentPage)
if err != nil {
return vulnPage, err
}
return vulnPage, nil
}
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
database.VulnerabilityNotificationWithVulnerable, bool, error) { database.VulnerabilityNotificationWithVulnerable, bool, error) {
var ( var (
noti database.VulnerabilityNotificationWithVulnerable noti database.VulnerabilityNotificationWithVulnerable
@ -294,7 +230,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return noti, false, nil return noti, false, nil
} }
return noti, false, handleError("searchNotification", err) return noti, false, util.HandleError("searchNotification", err)
} }
if created.Valid { if created.Valid {
@ -310,7 +246,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if oldVulnID.Valid { if oldVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }
@ -318,7 +254,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if newVulnID.Valid { if newVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }
@ -328,44 +264,44 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
return noti, true, nil return noti, true, nil
} }
func (tx *pgSession) MarkNotificationAsRead(name string) error { func MarkNotificationAsRead(tx *sql.Tx, name string) error {
if name == "" { if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed") return commonerr.NewBadRequestError("Empty notification name is not allowed")
} }
r, err := tx.Exec(updatedNotificationAsRead, name) r, err := tx.Exec(updatedNotificationAsRead, name)
if err != nil { if err != nil {
return handleError("updatedNotificationAsRead", err) return util.HandleError("updatedNotificationAsRead", err)
} }
affected, err := r.RowsAffected() affected, err := r.RowsAffected()
if err != nil { if err != nil {
return handleError("updatedNotificationAsRead", err) return util.HandleError("updatedNotificationAsRead", err)
} }
if affected <= 0 { if affected <= 0 {
return handleError("updatedNotificationAsRead", errNotificationNotFound) return util.HandleError("updatedNotificationAsRead", errNotificationNotFound)
} }
return nil return nil
} }
func (tx *pgSession) DeleteNotification(name string) error { func DeleteNotification(tx *sql.Tx, name string) error {
if name == "" { if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed") return commonerr.NewBadRequestError("Empty notification name is not allowed")
} }
result, err := tx.Exec(removeNotification, name) result, err := tx.Exec(removeNotification, name)
if err != nil { if err != nil {
return handleError("removeNotification", err) return util.HandleError("removeNotification", err)
} }
affected, err := result.RowsAffected() affected, err := result.RowsAffected()
if err != nil { if err != nil {
return handleError("removeNotification", err) return util.HandleError("removeNotification", err)
} }
if affected <= 0 { if affected <= 0 {
return handleError("removeNotification", commonerr.ErrNotFound) return util.HandleError("removeNotification", commonerr.ErrNotFound)
} }
return nil return nil

View File

@ -0,0 +1,24 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package page
// Page is the representation of a page for the Postgres schema.
type Page struct {
// StartID is the ID being used as the basis for pagination across database
// results. It is used to search for an ancestry with ID >= StartID.
//
// StartID is required to be unique to every ancestry and always increasing.
StartID int64
}

134
database/pgsql/pgsession.go Normal file
View File

@ -0,0 +1,134 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"time"
"github.com/coreos/clair/database/pgsql/keyvalue"
"github.com/coreos/clair/database/pgsql/vulnerability"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/ancestry"
"github.com/coreos/clair/database/pgsql/detector"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/layer"
"github.com/coreos/clair/database/pgsql/lock"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database/pgsql/notification"
"github.com/coreos/clair/pkg/pagination"
)
// Enforce the interface at compile time.
var _ database.Session = &pgSession{}
type pgSession struct {
*sql.Tx
key pagination.Key
}
func (tx *pgSession) UpsertAncestry(a database.Ancestry) error {
return ancestry.UpsertAncestry(tx.Tx, a)
}
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) {
return ancestry.FindAncestry(tx.Tx, name)
}
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
return detector.PersistDetectors(tx.Tx, detectors)
}
func (tx *pgSession) PersistFeatures(features []database.Feature) error {
return feature.PersistFeatures(tx.Tx, features)
}
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
return feature.PersistNamespacedFeatures(tx.Tx, features)
}
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
return vulnerability.CacheAffectedNamespacedFeatures(tx.Tx, features)
}
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
return vulnerability.FindAffectedNamespacedFeatures(tx.Tx, features)
}
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
return namespace.PersistNamespaces(tx.Tx, namespaces)
}
func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
return layer.PersistLayer(tx.Tx, hash, features, namespaces, detectedBy)
}
func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) {
return layer.FindLayer(tx.Tx, hash)
}
func (tx *pgSession) InsertVulnerabilities(vulns []database.VulnerabilityWithAffected) error {
return vulnerability.InsertVulnerabilities(tx.Tx, vulns)
}
func (tx *pgSession) FindVulnerabilities(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
return vulnerability.FindVulnerabilities(tx.Tx, ids)
}
func (tx *pgSession) DeleteVulnerabilities(ids []database.VulnerabilityID) error {
return vulnerability.DeleteVulnerabilities(tx.Tx, ids)
}
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
return notification.InsertVulnerabilityNotifications(tx.Tx, notifications)
}
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (hook database.NotificationHook, found bool, err error) {
return notification.FindNewNotification(tx.Tx, notifiedBefore)
}
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti database.VulnerabilityNotificationWithVulnerable, found bool, err error) {
return notification.FindVulnerabilityNotification(tx.Tx, name, limit, oldVulnerabilityPage, newVulnerabilityPage, tx.key)
}
func (tx *pgSession) MarkNotificationAsRead(name string) error {
return notification.MarkNotificationAsRead(tx.Tx, name)
}
func (tx *pgSession) DeleteNotification(name string) error {
return notification.DeleteNotification(tx.Tx, name)
}
func (tx *pgSession) UpdateKeyValue(key, value string) error {
return keyvalue.UpdateKeyValue(tx.Tx, key, value)
}
func (tx *pgSession) FindKeyValue(key string) (value string, found bool, err error) {
return keyvalue.FindKeyValue(tx.Tx, key)
}
func (tx *pgSession) AcquireLock(name, owner string, duration time.Duration) (acquired bool, expiration time.Time, err error) {
return lock.AcquireLock(tx.Tx, name, owner, duration)
}
func (tx *pgSession) ExtendLock(name, owner string, duration time.Duration) (extended bool, expiration time.Time, err error) {
return lock.ExtendLock(tx.Tx, name, owner, duration)
}
func (tx *pgSession) ReleaseLock(name, owner string) error {
return lock.ReleaseLock(tx.Tx, name, owner)
}

View File

@ -0,0 +1,56 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database/pgsql/testutil"
)
const (
numVulnerabilities = 100
numFeatures = 100
)
func TestConcurrency(t *testing.T) {
db, cleanup := testutil.CreateTestDB(t, "concurrency")
defer cleanup()
var wg sync.WaitGroup
// there's a limit on the number of concurrent connections in the pool
wg.Add(30)
for i := 0; i < 30; i++ {
go func() {
defer wg.Done()
nsNamespaces := testutil.GenRandomNamespaces(t, 100)
tx, err := db.Begin()
if err != nil {
panic(err)
}
assert.Nil(t, namespace.PersistNamespaces(tx, nsNamespaces))
if err := tx.Commit(); err != nil {
panic(err)
}
}()
}
wg.Wait()
}

View File

@ -21,13 +21,10 @@ import (
"io/ioutil" "io/ioutil"
"net/url" "net/url"
"strings" "strings"
"time"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/hashicorp/golang-lru" "github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"github.com/remind101/migrate" "github.com/remind101/migrate"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -37,50 +34,10 @@ import (
"github.com/coreos/clair/pkg/pagination" "github.com/coreos/clair/pkg/pagination"
) )
var (
promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_errors_total",
Help: "Number of errors that PostgreSQL requests generated.",
}, []string{"request"})
promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_hits_total",
Help: "Number of cache hits that the PostgreSQL backend did.",
}, []string{"object"})
promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_queries_total",
Help: "Number of cache queries that the PostgreSQL backend did.",
}, []string{"object"})
promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "clair_pgsql_query_duration_milliseconds",
Help: "Time it takes to execute the database query.",
}, []string{"query", "subquery"})
promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "clair_pgsql_concurrent_lock_vafv_total",
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
})
)
func init() { func init() {
prometheus.MustRegister(promErrorsTotal)
prometheus.MustRegister(promCacheHitsTotal)
prometheus.MustRegister(promCacheQueriesTotal)
prometheus.MustRegister(promQueryDurationMilliseconds)
prometheus.MustRegister(promConcurrentLockVAFV)
database.Register("pgsql", openDatabase) database.Register("pgsql", openDatabase)
} }
// pgSessionCache is the session's cache, which holds the pgSQL's cache and the
// individual session's cache. Only when session.Commit is called, all the
// changes to pgSQL cache will be applied.
type pgSessionCache struct {
c *lru.ARCCache
}
type pgSQL struct { type pgSQL struct {
*sql.DB *sql.DB
@ -88,12 +45,6 @@ type pgSQL struct {
config Config config Config
} }
type pgSession struct {
*sql.Tx
key pagination.Key
}
// Begin initiates a transaction to database. // Begin initiates a transaction to database.
// //
// The expected transaction isolation level in this implementation is "Read // The expected transaction isolation level in this implementation is "Read
@ -109,10 +60,6 @@ func (pgSQL *pgSQL) Begin() (database.Session, error) {
}, nil }, nil
} }
func (tx *pgSession) Commit() error {
return tx.Tx.Commit()
}
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
// the configuration. // the configuration.
func (pgSQL *pgSQL) Close() { func (pgSQL *pgSQL) Close() {
@ -131,15 +78,6 @@ func (pgSQL *pgSQL) Ping() bool {
return pgSQL.DB.Ping() == nil return pgSQL.DB.Ping() == nil
} }
// Page is the representation of a page for the Postgres schema.
type Page struct {
// StartID is the ID being used as the basis for pagination across database
// results. It is used to search for an ancestry with ID >= StartID.
//
// StartID is required to be unique to every ancestry and always increasing.
StartID int64
}
// Config is the configuration that is used by openDatabase. // Config is the configuration that is used by openDatabase.
type Config struct { type Config struct {
Source string Source string
@ -313,42 +251,3 @@ func dropDatabase(source, dbName string) error {
return nil return nil
} }
// handleError logs an error with an extra description and masks the error if it's an SQL one.
// The function ensures we never return plain SQL errors and leak anything.
// The function should be used for every database query error.
func handleError(desc string, err error) error {
if err == nil {
return nil
}
if err == sql.ErrNoRows {
return commonerr.ErrNotFound
}
log.WithError(err).WithField("Description", desc).Error("database: handled database error")
promErrorsTotal.WithLabelValues(desc).Inc()
if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
return database.ErrBackendException
}
return err
}
// isErrUniqueViolation determines is the given error is a unique contraint violation.
func isErrUniqueViolation(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// observeQueryTime computes the time elapsed since `start` to represent the
// query time.
// 1. `query` is a pgSession function name.
// 2. `subquery` is a specific query or a batched query.
// 3. `start` is the time right before query is executed.
func observeQueryTime(query, subquery string, start time.Time) {
promQueryDurationMilliseconds.
WithLabelValues(query, subquery).
Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond))
}

View File

@ -1,272 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/pborman/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
yaml "gopkg.in/yaml.v2"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/pagination"
)
var (
withFixtureName, withoutFixtureName string
)
var testPaginationKey = pagination.Must(pagination.NewKey())
func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) {
config := generateTestConfig(name, loadFixture, false)
source := config.Options["source"].(string)
name, url, err := parseConnectionString(source)
if err != nil {
panic(err)
}
fixturePath := config.Options["fixturepath"].(string)
if err := createDatabase(url, name); err != nil {
panic(err)
}
// migration and fixture
db, err := sql.Open("postgres", source)
if err != nil {
panic(err)
}
// Verify database state.
if err := db.Ping(); err != nil {
panic(err)
}
// Run migrations.
if err := migrateDatabase(db); err != nil {
panic(err)
}
if loadFixture {
d, err := ioutil.ReadFile(fixturePath)
if err != nil {
panic(err)
}
_, err = db.Exec(string(d))
if err != nil {
panic(err)
}
}
db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name)
db.Close()
log.Info("Generated Template database ", name)
return url, name
}
func dropTemplateDatabase(url string, name string) {
db, err := sql.Open("postgres", url)
if err != nil {
panic(err)
}
if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil {
panic(err)
}
if err := db.Close(); err != nil {
panic(err)
}
if err := dropDatabase(url, name); err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
fURL, fName := genTemplateDatabase("fixture", true)
nfURL, nfName := genTemplateDatabase("nonfixture", false)
withFixtureName = fName
withoutFixtureName = nfName
rc := m.Run()
dropTemplateDatabase(fURL, fName)
dropTemplateDatabase(nfURL, nfName)
os.Exit(rc)
}
func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) {
var fixtureName string
if fixture {
fixtureName = withFixtureName
} else {
fixtureName = withoutFixtureName
}
// copy the database into new database
var pg pgSQL
// Parse configuration.
pg.config = Config{
CacheSize: 16384,
}
bytes, err := yaml.Marshal(testConfig.Options)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
err = yaml.Unmarshal(bytes, &pg.config)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
if err != nil {
return nil, err
}
// Create database.
if pg.config.ManageDatabaseLifecycle {
if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil {
return nil, err
}
}
// Open database.
pg.DB, err = sql.Open("postgres", pg.config.Source)
fmt.Println("database", pg.config.Source)
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
return &pg, nil
}
// copyDatabase creates a new database with
func copyDatabase(url, name string, templateName string) error {
// Open database.
db, err := sql.Open("postgres", url)
if err != nil {
return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
}
defer db.Close()
// Create database with copy
_, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName)
if err != nil {
return fmt.Errorf("pgsql: could not create database: %v", err)
}
return nil
}
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
var (
db database.Datastore
err error
testConfig = generateTestConfig(testName, loadFixture, true)
)
db, err = openCopiedDatabase(testConfig, loadFixture)
if err != nil {
return nil, err
}
datastore := db.(*pgSQL)
return datastore, nil
}
func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig {
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
var fixturePath string
if loadFixture {
_, filename, _, _ := runtime.Caller(0)
fixturePath = filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
}
source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName)
if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" {
source = fmt.Sprintf(sourceEnv, dbName)
}
return database.RegistrableComponentConfig{
Options: map[string]interface{}{
"source": source,
"cachesize": 0,
"managedatabaselifecycle": manageLife,
"fixturepath": fixturePath,
"paginationkey": testPaginationKey.String(),
},
}
}
func closeTest(t *testing.T, store database.Datastore, session database.Session) {
err := session.Rollback()
if err != nil {
t.Error(err)
t.FailNow()
}
store.Close()
}
func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) {
store, err := openDatabaseForTest(name, loadFixture)
if err != nil {
t.Error(err)
t.FailNow()
}
tx, err := store.Begin()
if err != nil {
t.Error(err)
t.FailNow()
}
return store, tx.(*pgSession)
}
func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession {
var err error
if !commit {
err = tx.Rollback()
} else {
err = tx.Commit()
}
if assert.Nil(t, err) {
session, err := datastore.Begin()
if assert.Nil(t, err) {
return session.(*pgSession)
}
}
t.FailNow()
return nil
}

View File

@ -1,187 +0,0 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"fmt"
"strings"
"github.com/lib/pq"
)
// NOTE(Sida): Every search query can only have count less than postgres set
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
// stack depth. TODO(Sida): Generate different queries for different count: if
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
// count > 65535, use is expected to split data into batches.
func querySearchLastDeletedVulnerabilityID(count int) string {
return fmt.Sprintf(`
SELECT vid, vname, nname FROM (
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
row_number() OVER (
PARTITION by (v.name, n.name)
ORDER BY v.deleted_at DESC
) AS rownum
FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id
AND (v.name, n.name) IN ( %s )
AND v.deleted_at IS NOT NULL
) tmp WHERE rownum <= 1`,
queryString(2, count))
}
func querySearchNotDeletedVulnerabilityID(count int) string {
return fmt.Sprintf(`
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
AND v.deleted_at IS NULL`,
queryString(2, count))
}
func querySearchFeatureID(featureCount int) string {
return fmt.Sprintf(`
SELECT id, name, version, version_format, type
FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
queryString(4, featureCount),
)
}
func querySearchNamespacedFeature(nsfCount int) string {
return fmt.Sprintf(`
SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
WHERE nf.feature_id = f.id
AND nf.namespace_id = n.id
AND n.version_format = f.version_format
AND f.type = t.id
AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
queryString(5, nsfCount),
)
}
func querySearchNamespace(nsCount int) string {
return fmt.Sprintf(
`SELECT id, name, version_format
FROM namespace WHERE (name, version_format) IN (%s)`,
queryString(2, nsCount),
)
}
func queryInsert(count int, table string, columns ...string) string {
base := `INSERT INTO %s (%s) VALUES %s`
t := pq.QuoteIdentifier(table)
cols := make([]string, len(columns))
for i, c := range columns {
cols[i] = pq.QuoteIdentifier(c)
}
colsQuoted := strings.Join(cols, ",")
return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count))
}
func queryPersist(count int, table, constraint string, columns ...string) string {
ct := ""
if constraint != "" {
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
}
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct)
}
func queryInsertNotifications(count int) string {
return queryInsert(count,
"vulnerability_notification",
"name",
"created_at",
"old_vulnerability_id",
"new_vulnerability_id",
)
}
func queryPersistFeature(count int) string {
return queryPersist(count,
"feature",
"feature_name_version_version_format_type_key",
"name",
"version",
"version_format",
"type")
}
func queryPersistLayerFeature(count int) string {
return queryPersist(count,
"layer_feature",
"layer_feature_layer_id_feature_id_namespace_id_key",
"layer_id",
"feature_id",
"detector_id",
"namespace_id")
}
func queryPersistNamespace(count int) string {
return queryPersist(count,
"namespace",
"namespace_name_version_format_key",
"name",
"version_format")
}
func queryPersistLayerNamespace(count int) string {
return queryPersist(count,
"layer_namespace",
"layer_namespace_layer_id_namespace_id_key",
"layer_id",
"namespace_id",
"detector_id")
}
// size of key and array should be both greater than 0
func queryString(keySize, arraySize int) string {
if arraySize <= 0 || keySize <= 0 {
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
}
keys := make([]string, 0, arraySize)
for i := 0; i < arraySize; i++ {
key := make([]string, keySize)
for j := 0; j < keySize; j++ {
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
}
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
}
return strings.Join(keys, ",")
}
func queryPersistNamespacedFeature(count int) string {
return queryPersist(count, "namespaced_feature",
"namespaced_feature_namespace_id_feature_id_key",
"feature_id",
"namespace_id")
}
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
return queryPersist(count, "vulnerability_affected_namespaced_feature",
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
"vulnerability_id",
"namespaced_feature_id",
"added_by")
}
func queryPersistLayer(count int) string {
return queryPersist(count, "layer", "", "hash")
}
func queryInvalidateVulnerabilityCache(count int) string {
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
WHERE vulnerability_id IN (%s)`,
queryString(1, count))
}

View File

@ -1,424 +0,0 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"fmt"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/remind101/migrate"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/pkg/pagination"
)
// int keys must be the consistent with the database ID.
var (
realFeatures = map[int]database.Feature{
1: {"ourchat", "0.5", "dpkg", "source"},
2: {"openssl", "1.0", "dpkg", "source"},
3: {"openssl", "2.0", "dpkg", "source"},
4: {"fake", "2.0", "rpm", "source"},
5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"},
}
realNamespaces = map[int]database.Namespace{
1: {"debian:7", "dpkg"},
2: {"debian:8", "dpkg"},
3: {"fake:1.0", "rpm"},
4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"},
}
realNamespacedFeatures = map[int]database.NamespacedFeature{
1: {realFeatures[1], realNamespaces[1]},
2: {realFeatures[2], realNamespaces[1]},
3: {realFeatures[2], realNamespaces[2]},
4: {realFeatures[3], realNamespaces[1]},
}
realDetectors = map[int]database.Detector{
1: database.NewNamespaceDetector("os-release", "1.0"),
2: database.NewFeatureDetector("dpkg", "1.0"),
3: database.NewFeatureDetector("rpm", "1.0"),
4: database.NewNamespaceDetector("apt-sources", "1.0"),
}
realLayers = map[int]database.Layer{
2: {
Hash: "layer-1",
By: []database.Detector{realDetectors[1], realDetectors[2]},
Features: []database.LayerFeature{
{realFeatures[1], realDetectors[2], database.Namespace{}},
{realFeatures[2], realDetectors[2], database.Namespace{}},
},
Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]},
},
},
6: {
Hash: "layer-4",
By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
Features: []database.LayerFeature{
{realFeatures[4], realDetectors[3], database.Namespace{}},
{realFeatures[3], realDetectors[2], database.Namespace{}},
},
Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]},
{realNamespaces[3], realDetectors[4]},
},
},
}
realAncestries = map[int]database.Ancestry{
2: {
Name: "ancestry-2",
By: []database.Detector{realDetectors[2], realDetectors[1]},
Layers: []database.AncestryLayer{
{
"layer-0",
[]database.AncestryFeature{},
},
{
"layer-1",
[]database.AncestryFeature{},
},
{
"layer-2",
[]database.AncestryFeature{
{
realNamespacedFeatures[1],
realDetectors[2],
realDetectors[1],
},
},
},
{
"layer-3b",
[]database.AncestryFeature{
{
realNamespacedFeatures[3],
realDetectors[2],
realDetectors[1],
},
},
},
},
},
}
realVulnerability = map[int]database.Vulnerability{
1: {
Name: "CVE-OPENSSL-1-DEB7",
Namespace: realNamespaces[1],
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: database.HighSeverity,
},
2: {
Name: "CVE-NOPE",
Namespace: realNamespaces[1],
Description: "A vulnerability affecting nothing",
Severity: database.UnknownSeverity,
},
}
realNotification = map[int]database.VulnerabilityNotification{
1: {
NotificationHook: database.NotificationHook{
Name: "test",
},
Old: takeVulnerabilityPointerFromMap(realVulnerability, 2),
New: takeVulnerabilityPointerFromMap(realVulnerability, 1),
},
}
fakeFeatures = map[int]database.Feature{
1: {
Name: "ourchat",
Version: "0.6",
VersionFormat: "dpkg",
Type: "source",
},
}
fakeNamespaces = map[int]database.Namespace{
1: {"green hat", "rpm"},
}
fakeNamespacedFeatures = map[int]database.NamespacedFeature{
1: {
Feature: fakeFeatures[0],
Namespace: realNamespaces[0],
},
}
fakeDetector = map[int]database.Detector{
1: {
Name: "fake",
Version: "1.0",
DType: database.FeatureDetectorType,
},
2: {
Name: "fake2",
Version: "2.0",
DType: database.NamespaceDetectorType,
},
}
)
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
x := m[id]
return &x
}
func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
x := m[id]
return &x
}
func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
x := m[id]
return &x
}
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
rows, err := tx.Query("SELECT name, version_format FROM namespace")
if err != nil {
t.FailNow()
}
defer rows.Close()
namespaces := []database.Namespace{}
for rows.Next() {
var ns database.Namespace
err := rows.Scan(&ns.Name, &ns.VersionFormat)
if err != nil {
t.FailNow()
}
namespaces = append(namespaces, ns)
}
return namespaces
}
func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
}
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
assert.Equal(t, expected.Limit, actual.Limit) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
assert.Equal(t, expected.End, actual.End) &&
database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
}
func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page {
if token == pagination.FirstPageToken {
return Page{}
}
p := Page{}
if err := key.UnmarshalToken(token, &p); err != nil {
panic(err)
}
return p
}
func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
token, err := key.MarshalToken(v)
if err != nil {
panic(err)
}
return token
}
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
func createAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
uri := "postgres@127.0.0.1:5432"
connectionTemplate := "postgresql://%s?sslmode=disable"
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
uri = envURI
}
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
if err != nil {
panic(err)
}
testName = strings.ToLower(testName)
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
t.Logf("creating temporary database name = %s", dbName)
_, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil {
panic(err)
}
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
if err != nil {
panic(err)
}
return testDB, func() {
t.Logf("cleaning up temporary database %s", dbName)
defer db.Close()
if err := testDB.Close(); err != nil {
panic(err)
}
if _, err := db.Exec(`DROP DATABASE ` + dbName); err != nil {
panic(err)
}
// ensure the database is cleaned up
var count int
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
panic(err)
}
}
}
func createTestPgSQL(t *testing.T, testName string) (*pgSQL, func()) {
connection, cleanup := createAndConnectTestDB(t, testName)
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
if err != nil {
require.Nil(t, err, err.Error())
}
return &pgSQL{connection, nil, Config{PaginationKey: pagination.Must(pagination.NewKey()).String()}}, cleanup
}
func createTestPgSQLWithFixtures(t *testing.T, testName string) (*pgSQL, func()) {
connection, cleanup := createTestPgSQL(t, testName)
session, err := connection.Begin()
if err != nil {
panic(err)
}
defer session.Rollback()
loadFixtures(session.(*pgSession))
if err = session.Commit(); err != nil {
panic(err)
}
return connection, cleanup
}
func createTestPgSession(t *testing.T, testName string) (*pgSession, func()) {
connection, cleanup := createTestPgSQL(t, testName)
session, err := connection.Begin()
if err != nil {
panic(err)
}
return session.(*pgSession), func() {
session.Rollback()
cleanup()
}
}
func createTestPgSessionWithFixtures(t *testing.T, testName string) (*pgSession, func()) {
tx, cleanup := createTestPgSession(t, testName)
defer func() {
// ensure to cleanup when loadFixtures failed
if r := recover(); r != nil {
cleanup()
}
}()
loadFixtures(tx)
return tx, cleanup
}
func loadFixtures(tx *pgSession) {
_, filename, _, _ := runtime.Caller(0)
fixturePath := filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
d, err := ioutil.ReadFile(fixturePath)
if err != nil {
panic(err)
}
_, err = tx.Exec(string(d))
if err != nil {
panic(err)
}
}
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.AffectedFeature]bool{}
for _, i := range expected {
has[i] = false
}
for _, i := range actual {
if visited, ok := has[i]; !ok {
return false
} else if visited {
return false
}
has[i] = true
}
return true
}
return false
}
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
r := make([]database.Namespace, count)
for i := 0; i < count; i++ {
r[i] = database.Namespace{
Name: fmt.Sprint(rand.Int()),
VersionFormat: "dpkg",
}
}
return r
}

View File

@ -0,0 +1,98 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"testing"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/pagination"
"github.com/stretchr/testify/assert"
)
func AssertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
}
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
assert.Equal(t, expected.Limit, actual.Limit) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
assert.Equal(t, expected.End, actual.End) &&
database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
}
func AssertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && AssertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}
func AssertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.AffectedFeature]bool{}
for _, i := range expected {
has[i] = false
}
for _, i := range actual {
if visited, ok := has[i]; !ok {
return false
} else if visited {
return false
}
has[i] = true
}
return true
}
return false
}
func AssertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.NamespacedFeature]bool{}
for _, nf := range expected {
has[nf] = false
}
for _, nf := range actual {
has[nf] = true
}
for nf, visited := range has {
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
return false
}
}
return true
}
return false
}

View File

@ -0,0 +1,171 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import "github.com/coreos/clair/database"
// int keys must be the consistent with the database ID.
var (
RealFeatures = map[int]database.Feature{
1: {"ourchat", "0.5", "dpkg", "source"},
2: {"openssl", "1.0", "dpkg", "source"},
3: {"openssl", "2.0", "dpkg", "source"},
4: {"fake", "2.0", "rpm", "source"},
5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"},
}
RealNamespaces = map[int]database.Namespace{
1: {"debian:7", "dpkg"},
2: {"debian:8", "dpkg"},
3: {"fake:1.0", "rpm"},
4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"},
}
RealNamespacedFeatures = map[int]database.NamespacedFeature{
1: {RealFeatures[1], RealNamespaces[1]},
2: {RealFeatures[2], RealNamespaces[1]},
3: {RealFeatures[2], RealNamespaces[2]},
4: {RealFeatures[3], RealNamespaces[1]},
}
RealDetectors = map[int]database.Detector{
1: database.NewNamespaceDetector("os-release", "1.0"),
2: database.NewFeatureDetector("dpkg", "1.0"),
3: database.NewFeatureDetector("rpm", "1.0"),
4: database.NewNamespaceDetector("apt-sources", "1.0"),
}
RealLayers = map[int]database.Layer{
2: {
Hash: "layer-1",
By: []database.Detector{RealDetectors[1], RealDetectors[2]},
Features: []database.LayerFeature{
{RealFeatures[1], RealDetectors[2], database.Namespace{}},
{RealFeatures[2], RealDetectors[2], database.Namespace{}},
},
Namespaces: []database.LayerNamespace{
{RealNamespaces[1], RealDetectors[1]},
},
},
6: {
Hash: "layer-4",
By: []database.Detector{RealDetectors[1], RealDetectors[2], RealDetectors[3], RealDetectors[4]},
Features: []database.LayerFeature{
{RealFeatures[4], RealDetectors[3], database.Namespace{}},
{RealFeatures[3], RealDetectors[2], database.Namespace{}},
},
Namespaces: []database.LayerNamespace{
{RealNamespaces[1], RealDetectors[1]},
{RealNamespaces[3], RealDetectors[4]},
},
},
}
RealAncestries = map[int]database.Ancestry{
2: {
Name: "ancestry-2",
By: []database.Detector{RealDetectors[2], RealDetectors[1]},
Layers: []database.AncestryLayer{
{
"layer-0",
[]database.AncestryFeature{},
},
{
"layer-1",
[]database.AncestryFeature{},
},
{
"layer-2",
[]database.AncestryFeature{
{
RealNamespacedFeatures[1],
RealDetectors[2],
RealDetectors[1],
},
},
},
{
"layer-3b",
[]database.AncestryFeature{
{
RealNamespacedFeatures[3],
RealDetectors[2],
RealDetectors[1],
},
},
},
},
},
}
RealVulnerability = map[int]database.Vulnerability{
1: {
Name: "CVE-OPENSSL-1-DEB7",
Namespace: RealNamespaces[1],
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: database.HighSeverity,
},
2: {
Name: "CVE-NOPE",
Namespace: RealNamespaces[1],
Description: "A vulnerability affecting nothing",
Severity: database.UnknownSeverity,
},
}
RealNotification = map[int]database.VulnerabilityNotification{
1: {
NotificationHook: database.NotificationHook{
Name: "test",
},
Old: takeVulnerabilityPointerFromMap(RealVulnerability, 2),
New: takeVulnerabilityPointerFromMap(RealVulnerability, 1),
},
}
FakeFeatures = map[int]database.Feature{
1: {
Name: "ourchat",
Version: "0.6",
VersionFormat: "dpkg",
Type: "source",
},
}
FakeNamespaces = map[int]database.Namespace{
1: {"green hat", "rpm"},
}
FakeNamespacedFeatures = map[int]database.NamespacedFeature{
1: {
Feature: FakeFeatures[0],
Namespace: RealNamespaces[0],
},
}
FakeDetector = map[int]database.Detector{
1: {
Name: "fake",
Version: "1.0",
DType: database.FeatureDetectorType,
},
2: {
Name: "fake2",
Version: "2.0",
DType: database.NamespaceDetectorType,
},
}
)

View File

@ -0,0 +1,207 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/pkg/pagination"
"github.com/remind101/migrate"
)
var TestPaginationKey = pagination.Must(pagination.NewKey())
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
func CreateAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
uri := "postgres@127.0.0.1:5432"
connectionTemplate := "postgresql://%s?sslmode=disable"
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
uri = envURI
}
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
if err != nil {
panic(err)
}
testName = strings.ToLower(testName)
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
t.Logf("creating temporary database name = %s", dbName)
_, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil {
panic(err)
}
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
if err != nil {
panic(err)
}
return testDB, func() {
cleanupTestDB(t, dbName, db, testDB)
}
}
func cleanupTestDB(t *testing.T, name string, db, testDB *sql.DB) {
t.Logf("cleaning up temporary database %s", name)
if db == nil {
panic("db is none")
}
if testDB == nil {
panic("testDB is none")
}
defer db.Close()
if err := testDB.Close(); err != nil {
panic(err)
}
// Kill any opened connection.
if _, err := db.Exec(`
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = $1
AND pid <> pg_backend_pid()`, name); err != nil {
panic(err)
}
if _, err := db.Exec(`DROP DATABASE ` + name); err != nil {
panic(err)
}
// ensure the database is cleaned up
var count int
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
panic(err)
}
}
func CreateTestDB(t *testing.T, testName string) (*sql.DB, func()) {
connection, cleanup := CreateAndConnectTestDB(t, testName)
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
if err != nil {
panic(err)
}
return connection, cleanup
}
func CreateTestDBWithFixture(t *testing.T, testName string) (*sql.DB, func()) {
connection, cleanup := CreateTestDB(t, testName)
session, err := connection.Begin()
if err != nil {
panic(err)
}
defer session.Rollback()
loadFixtures(session)
if err = session.Commit(); err != nil {
panic(err)
}
return connection, cleanup
}
func CreateTestTx(t *testing.T, testName string) (*sql.Tx, func()) {
connection, cleanup := CreateTestDB(t, testName)
session, err := connection.Begin()
if session == nil {
panic("session is none")
}
if err != nil {
panic(err)
}
return session, func() {
session.Rollback()
cleanup()
}
}
func CreateTestTxWithFixtures(t *testing.T, testName string) (*sql.Tx, func()) {
tx, cleanup := CreateTestTx(t, testName)
defer func() {
// ensure to cleanup when loadFixtures failed
if r := recover(); r != nil {
cleanup()
}
}()
loadFixtures(tx)
return tx, cleanup
}
func loadFixtures(tx *sql.Tx) {
_, filename, _, _ := runtime.Caller(0)
fixturePath := filepath.Join(filepath.Dir(filename), "data.sql")
d, err := ioutil.ReadFile(fixturePath)
if err != nil {
panic(err)
}
_, err = tx.Exec(string(d))
if err != nil {
panic(err)
}
}
func OpenSessionForTest(t *testing.T, name string, loadFixture bool) (*sql.DB, *sql.Tx) {
var db *sql.DB
if loadFixture {
db, _ = CreateTestDB(t, name)
} else {
db, _ = CreateTestDBWithFixture(t, name)
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
return db, tx
}
func RestartTransaction(db *sql.DB, tx *sql.Tx, commit bool) *sql.Tx {
if !commit {
if err := tx.Rollback(); err != nil {
panic(err)
}
} else {
if err := tx.Commit(); err != nil {
panic(err)
}
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
return tx
}

View File

@ -0,0 +1,94 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"database/sql"
"fmt"
"math/rand"
"testing"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/page"
"github.com/coreos/clair/pkg/pagination"
)
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
x := m[id]
return &x
}
func TakeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
x := m[id]
return &x
}
func TakeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
x := m[id]
return &x
}
func ListNamespaces(t *testing.T, tx *sql.Tx) []database.Namespace {
rows, err := tx.Query("SELECT name, version_format FROM namespace")
if err != nil {
t.FailNow()
}
defer rows.Close()
namespaces := []database.Namespace{}
for rows.Next() {
var ns database.Namespace
err := rows.Scan(&ns.Name, &ns.VersionFormat)
if err != nil {
t.FailNow()
}
namespaces = append(namespaces, ns)
}
return namespaces
}
func mustUnmarshalToken(key pagination.Key, token pagination.Token) page.Page {
if token == pagination.FirstPageToken {
return page.Page{}
}
p := page.Page{}
if err := key.UnmarshalToken(token, &p); err != nil {
panic(err)
}
return p
}
func MustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
token, err := key.MarshalToken(v)
if err != nil {
panic(err)
}
return token
}
func GenRandomNamespaces(t *testing.T, count int) []database.Namespace {
r := make([]database.Namespace, count)
for i := 0; i < count; i++ {
r[i] = database.Namespace{
Name: fmt.Sprint(rand.Int()),
VersionFormat: "dpkg",
}
}
return r
}

View File

@ -0,0 +1,63 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"database/sql"
"strings"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/monitoring"
"github.com/coreos/clair/pkg/commonerr"
"github.com/lib/pq"
"github.com/sirupsen/logrus"
)
// IsErrUniqueViolation determines is the given error is a unique contraint violation.
func IsErrUniqueViolation(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// HandleError logs an error with an extra description and masks the error if it's an SQL one.
// The function ensures we never return plain SQL errors and leak anything.
// The function should be used for every database query error.
func HandleError(desc string, err error) error {
if err == nil {
return nil
}
if err == sql.ErrNoRows {
return commonerr.ErrNotFound
}
if pqErr, ok := err.(*pq.Error); ok {
if pqErr.Fatal() {
panic(pqErr)
}
if pqErr.Code == "42601" {
panic("invalid query: " + desc + ", info: " + err.Error())
}
}
logrus.WithError(err).WithField("Description", desc).Error("database: handled database error")
monitoring.PromErrorsTotal.WithLabelValues(desc).Inc()
if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
return database.ErrBackendException
}
return err
}

View File

@ -0,0 +1,57 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"fmt"
"strings"
"github.com/lib/pq"
)
func QueryInsert(count int, table string, columns ...string) string {
base := `INSERT INTO %s (%s) VALUES %s`
t := pq.QuoteIdentifier(table)
cols := make([]string, len(columns))
for i, c := range columns {
cols[i] = pq.QuoteIdentifier(c)
}
colsQuoted := strings.Join(cols, ",")
return fmt.Sprintf(base, t, colsQuoted, QueryString(len(columns), count))
}
func QueryPersist(count int, table, constraint string, columns ...string) string {
ct := ""
if constraint != "" {
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
}
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", QueryInsert(count, table, columns...), ct)
}
// size of key and array should be both greater than 0
func QueryString(keySize, arraySize int) string {
if arraySize <= 0 || keySize <= 0 {
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
}
keys := make([]string, 0, arraySize)
for i := 0; i < arraySize; i++ {
key := make([]string, keySize)
for j := 0; j < keySize; j++ {
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
}
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
}
return strings.Join(keys, ",")
}

View File

@ -12,23 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package vulnerability
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"time" "time"
"github.com/lib/pq" "github.com/lib/pq"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/monitoring"
"github.com/coreos/clair/database/pgsql/page"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/pkg/pagination"
) )
const ( const (
lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
searchVulnerability = ` searchVulnerability = `
SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format
FROM vulnerability AS v, namespace AS n FROM vulnerability AS v, namespace AS n
@ -38,45 +42,12 @@ const (
AND v.deleted_at IS NULL AND v.deleted_at IS NULL
` `
insertVulnerabilityAffected = `
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
VALUES ($1, $2, $3, $4, $5)
RETURNING ID
`
searchVulnerabilityAffected = `
SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
FROM vulnerability_affected_feature AS vaf, feature_type AS t
WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
`
searchVulnerabilityByID = ` searchVulnerabilityByID = `
SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
FROM vulnerability AS v, namespace AS n FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id WHERE v.namespace_id = n.id
AND v.id = $1` AND v.id = $1`
searchVulnerabilityPotentialAffected = `
WITH req AS (
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
FROM vulnerability_affected_feature AS vaf,
vulnerability AS v,
namespace AS n
WHERE vaf.vulnerability_id = ANY($1)
AND v.id = vaf.vulnerability_id
AND n.id = v.namespace_id
)
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
FROM feature AS f, namespaced_feature AS nf, req
WHERE f.name = req.name
AND f.type = req.type
AND nf.namespace_id = req.n_id
AND nf.feature_id = f.id`
insertVulnerabilityAffectedNamespacedFeature = `
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
VALUES ($1, $2, $3)`
insertVulnerability = ` insertVulnerability = `
WITH ns AS ( WITH ns AS (
SELECT id FROM namespace WHERE name = $6 AND version_format = $7 SELECT id FROM namespace WHERE name = $6 AND version_format = $7
@ -92,11 +63,55 @@ const (
AND name = $2 AND name = $2
AND deleted_at IS NULL AND deleted_at IS NULL
RETURNING id` RETURNING id`
searchNotificationVulnerableAncestry = `
SELECT DISTINCT ON (a.id)
a.id, a.name
FROM vulnerability_affected_namespaced_feature AS vanf,
ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
WHERE vanf.vulnerability_id = $1
AND a.id >= $2
AND al.ancestry_id = a.id
AND al.id = af.ancestry_layer_id
AND af.namespaced_feature_id = vanf.namespaced_feature_id
ORDER BY a.id ASC
LIMIT $3;`
) )
var ( func queryInvalidateVulnerabilityCache(count int) string {
errVulnerabilityNotFound = errors.New("vulnerability is not in database") return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
) WHERE vulnerability_id IN (%s)`,
util.QueryString(1, count))
}
// NOTE(Sida): Every search query can only have count less than postgres set
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
// stack depth. TODO(Sida): Generate different queries for different count: if
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
// count > 65535, use is expected to split data into batches.
func querySearchLastDeletedVulnerabilityID(count int) string {
return fmt.Sprintf(`
SELECT vid, vname, nname FROM (
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
row_number() OVER (
PARTITION by (v.name, n.name)
ORDER BY v.deleted_at DESC
) AS rownum
FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id
AND (v.name, n.name) IN ( %s )
AND v.deleted_at IS NOT NULL
) tmp WHERE rownum <= 1`,
util.QueryString(2, count))
}
func querySearchNotDeletedVulnerabilityID(count int) string {
return fmt.Sprintf(`
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
AND v.deleted_at IS NULL`,
util.QueryString(2, count))
}
type affectedAncestry struct { type affectedAncestry struct {
name string name string
@ -113,8 +128,8 @@ type affectedFeatureRows struct {
rows map[int64]database.AffectedFeature rows map[int64]database.AffectedFeature
} }
func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
defer observeQueryTime("findVulnerabilities", "", time.Now()) defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now())
resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
vulnIDMap := map[int64][]*database.NullableVulnerability{} vulnIDMap := map[int64][]*database.NullableVulnerability{}
@ -151,7 +166,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
stmt.Close() stmt.Close()
return nil, handleError("searchVulnerability", err) return nil, util.HandleError("searchVulnerability", err)
} }
vuln.Valid = id.Valid vuln.Valid = id.Valid
resultVuln[i] = vuln resultVuln[i] = vuln
@ -161,7 +176,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
} }
if err := stmt.Close(); err != nil { if err := stmt.Close(); err != nil {
return nil, handleError("searchVulnerability", err) return nil, util.HandleError("searchVulnerability", err)
} }
toQuery := make([]int64, 0, len(vulnIDMap)) toQuery := make([]int64, 0, len(vulnIDMap))
@ -172,7 +187,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
// load vulnerability affected features // load vulnerability affected features
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
if err != nil { if err != nil {
return nil, handleError("searchVulnerabilityAffected", err) return nil, util.HandleError("searchVulnerabilityAffected", err)
} }
for rows.Next() { for rows.Next() {
@ -183,7 +198,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion) err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion)
if err != nil { if err != nil {
return nil, handleError("searchVulnerabilityAffected", err) return nil, util.HandleError("searchVulnerabilityAffected", err)
} }
for _, vuln := range vulnIDMap[id] { for _, vuln := range vulnIDMap[id] {
@ -195,41 +210,40 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
return resultVuln, nil return resultVuln, nil
} }
func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error {
defer observeQueryTime("insertVulnerabilities", "all", time.Now()) defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now())
// bulk insert vulnerabilities // bulk insert vulnerabilities
vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) vulnIDs, err := insertVulnerabilities(tx, vulnerabilities)
if err != nil { if err != nil {
return err return err
} }
// bulk insert vulnerability affected features // bulk insert vulnerability affected features
vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities)
if err != nil { if err != nil {
return err return err
} }
return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap)
} }
// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
// //
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. // i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
var ( var (
vulnFeature = map[int64]affectedFeatureRows{} vulnFeature = map[int64]affectedFeatureRows{}
affectedID int64 affectedID int64
) )
types, err := tx.getFeatureTypeMap() types, err := feature.GetFeatureTypeMap(tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
//TODO(Sida): Change to bulk insert.
stmt, err := tx.Prepare(insertVulnerabilityAffected) stmt, err := tx.Prepare(insertVulnerabilityAffected)
if err != nil { if err != nil {
return nil, handleError("insertVulnerabilityAffected", err) return nil, util.HandleError("insertVulnerabilityAffected", err)
} }
defer stmt.Close() defer stmt.Close()
@ -237,9 +251,9 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
// affected feature row ID -> affected feature // affected feature row ID -> affected feature
affectedFeatures := map[int64]database.AffectedFeature{} affectedFeatures := map[int64]database.AffectedFeature{}
for _, f := range vuln.Affected { for _, f := range vuln.Affected {
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
if err != nil { if err != nil {
return nil, handleError("insertVulnerabilityAffected", err) return nil, util.HandleError("insertVulnerabilityAffected", err)
} }
affectedFeatures[affectedID] = f affectedFeatures[affectedID] = f
} }
@ -251,7 +265,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
// insertVulnerabilities inserts a set of unique vulnerabilities into database, // insertVulnerabilities inserts a set of unique vulnerabilities into database,
// under the assumption that all vulnerabilities are valid. // under the assumption that all vulnerabilities are valid.
func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
var ( var (
vulnID int64 vulnID int64
vulnIDs = make([]int64, 0, len(vulnerabilities)) vulnIDs = make([]int64, 0, len(vulnerabilities))
@ -274,7 +288,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
//TODO(Sida): Change to bulk insert. //TODO(Sida): Change to bulk insert.
stmt, err := tx.Prepare(insertVulnerability) stmt, err := tx.Prepare(insertVulnerability)
if err != nil { if err != nil {
return nil, handleError("insertVulnerability", err) return nil, util.HandleError("insertVulnerability", err)
} }
defer stmt.Close() defer stmt.Close()
@ -283,7 +297,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
vuln.Link, &vuln.Severity, &vuln.Metadata, vuln.Link, &vuln.Severity, &vuln.Metadata,
vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
if err != nil { if err != nil {
return nil, handleError("insertVulnerability", err) return nil, util.HandleError("insertVulnerability", err)
} }
vulnIDs = append(vulnIDs, vulnID) vulnIDs = append(vulnIDs, vulnID)
@ -292,19 +306,19 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
return vulnIDs, nil return vulnIDs, nil
} }
func (tx *pgSession) lockFeatureVulnerabilityCache() error { func LockFeatureVulnerabilityCache(tx *sql.Tx) error {
_, err := tx.Exec(lockVulnerabilityAffects) _, err := tx.Exec(lockVulnerabilityAffects)
if err != nil { if err != nil {
return handleError("lockVulnerabilityAffects", err) return util.HandleError("lockVulnerabilityAffects", err)
} }
return nil return nil
} }
// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
// to affected feature rows and caches them. // to affected feature rows and caches them.
func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error {
// Prevent InsertNamespacedFeatures to modify it. // Prevent InsertNamespacedFeatures to modify it.
err := tx.lockFeatureVulnerabilityCache() err := LockFeatureVulnerabilityCache(tx)
if err != nil { if err != nil {
return err return err
} }
@ -316,7 +330,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
if err != nil { if err != nil {
return handleError("searchVulnerabilityPotentialAffected", err) return util.HandleError("searchVulnerabilityPotentialAffected", err)
} }
defer rows.Close() defer rows.Close()
@ -332,7 +346,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
if err != nil { if err != nil {
return handleError("searchVulnerabilityPotentialAffected", err) return util.HandleError("searchVulnerabilityPotentialAffected", err)
} }
candidate, ok := affected[vulnID].rows[addedBy] candidate, ok := affected[vulnID].rows[addedBy]
@ -361,7 +375,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
for _, r := range relation { for _, r := range relation {
result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
if err != nil { if err != nil {
return handleError("insertVulnerabilityAffectedNamespacedFeature", err) return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err)
} }
if num, err := result.RowsAffected(); err == nil { if num, err := result.RowsAffected(); err == nil {
@ -377,27 +391,27 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
return nil return nil
} }
func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error {
defer observeQueryTime("DeleteVulnerability", "all", time.Now()) defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now())
vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities)
if err != nil { if err != nil {
return err return err
} }
if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil {
return err return err
} }
return nil return nil
} }
func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error {
if len(vulnerabilityIDs) == 0 { if len(vulnerabilityIDs) == 0 {
return nil return nil
} }
// Prevent InsertNamespacedFeatures to modify it. // Prevent InsertNamespacedFeatures to modify it.
err := tx.lockFeatureVulnerabilityCache() err := LockFeatureVulnerabilityCache(tx)
if err != nil { if err != nil {
return err return err
} }
@ -410,13 +424,13 @@ func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) erro
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
if err != nil { if err != nil {
return handleError("removeVulnerabilityAffectedFeature", err) return util.HandleError("removeVulnerabilityAffectedFeature", err)
} }
return nil return nil
} }
func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) {
var ( var (
vulnID sql.NullInt64 vulnID sql.NullInt64
vulnIDs []int64 vulnIDs []int64
@ -425,17 +439,17 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul
// mark vulnerabilities deleted // mark vulnerabilities deleted
stmt, err := tx.Prepare(removeVulnerability) stmt, err := tx.Prepare(removeVulnerability)
if err != nil { if err != nil {
return nil, handleError("removeVulnerability", err) return nil, util.HandleError("removeVulnerability", err)
} }
defer stmt.Close() defer stmt.Close()
for _, vuln := range vulnerabilities { for _, vuln := range vulnerabilities {
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
if err != nil { if err != nil {
return nil, handleError("removeVulnerability", err) return nil, util.HandleError("removeVulnerability", err)
} }
if !vulnID.Valid { if !vulnID.Valid {
return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
} }
vulnIDs = append(vulnIDs, vulnID.Int64) vulnIDs = append(vulnIDs, vulnID.Int64)
} }
@ -444,15 +458,15 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in // findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
// database and the order of output array is not guaranteed. // database and the order of output array is not guaranteed.
func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
return tx.findVulnerabilityIDs(vulnIDs, true) return FindVulnerabilityIDs(tx, vulnIDs, true)
} }
func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
return tx.findVulnerabilityIDs(vulnIDs, false) return FindVulnerabilityIDs(tx, vulnIDs, false)
} }
func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
if len(vulnIDs) == 0 { if len(vulnIDs) == 0 {
return nil, nil return nil, nil
} }
@ -474,7 +488,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
rows, err := tx.Query(query, keys...) rows, err := tx.Query(query, keys...)
if err != nil { if err != nil {
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
} }
defer rows.Close() defer rows.Close()
@ -485,7 +499,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
for rows.Next() { for rows.Next() {
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
if err != nil { if err != nil {
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
} }
vulnIDMap[vulnID] = id vulnIDMap[vulnID] = id
} }
@ -497,3 +511,67 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
return ids, nil return ids, nil
} }
func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) {
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
currentPage := page.Page{0}
if currentToken != pagination.FirstPageToken {
if err := key.UnmarshalToken(currentToken, &currentPage); err != nil {
return vulnPage, err
}
}
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
&vulnPage.Name,
&vulnPage.Description,
&vulnPage.Link,
&vulnPage.Severity,
&vulnPage.Metadata,
&vulnPage.Namespace.Name,
&vulnPage.Namespace.VersionFormat,
); err != nil {
return vulnPage, util.HandleError("searchVulnerabilityByID", err)
}
// the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
if err != nil {
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
}
defer rows.Close()
ancestries := []affectedAncestry{}
for rows.Next() {
var ancestry affectedAncestry
err := rows.Scan(&ancestry.id, &ancestry.name)
if err != nil {
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
}
ancestries = append(ancestries, ancestry)
}
lastIndex := 0
if len(ancestries)-1 < limit {
lastIndex = len(ancestries)
vulnPage.End = true
} else {
// Use the last ancestry's ID as the next page.
lastIndex = len(ancestries) - 1
vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id})
if err != nil {
return vulnPage, err
}
}
vulnPage.Affected = map[int]string{}
for _, ancestry := range ancestries[0:lastIndex] {
vulnPage.Affected[int(ancestry.id)] = ancestry.name
}
vulnPage.Current, err = key.MarshalToken(currentPage)
if err != nil {
return vulnPage, err
}
return vulnPage, nil
}

View File

@ -0,0 +1,118 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vulnerability
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/ext/versionfmt"
"github.com/lib/pq"
)
const (
searchPotentialAffectingVulneraibilities = `
SELECT nf.id, v.id, vaf.affected_version, vaf.id
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
namespaced_feature AS nf, feature AS f
WHERE nf.id = ANY($1)
AND nf.feature_id = f.id
AND nf.namespace_id = v.namespace_id
AND vaf.feature_name = f.name
AND vaf.feature_type = f.type
AND vaf.vulnerability_id = v.id
AND v.deleted_at IS NULL`
insertVulnerabilityAffected = `
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
VALUES ($1, $2, $3, $4, $5)
RETURNING ID
`
searchVulnerabilityAffected = `
SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
FROM vulnerability_affected_feature AS vaf, feature_type AS t
WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
`
searchVulnerabilityPotentialAffected = `
WITH req AS (
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
FROM vulnerability_affected_feature AS vaf,
vulnerability AS v,
namespace AS n
WHERE vaf.vulnerability_id = ANY($1)
AND v.id = vaf.vulnerability_id
AND n.id = v.namespace_id
)
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
FROM feature AS f, namespaced_feature AS nf, req
WHERE f.name = req.name
AND f.type = req.type
AND nf.namespace_id = req.n_id
AND nf.feature_id = f.id`
)
type vulnerabilityCache struct {
nsFeatureID int64
vulnID int64
vulnAffectingID int64
}
func SearchAffectingVulnerabilities(tx *sql.Tx, features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
if len(features) == 0 {
return nil, nil
}
ids, err := feature.FindNamespacedFeatureIDs(tx, features)
if err != nil {
return nil, err
}
fMap := map[int64]database.NamespacedFeature{}
for i, f := range features {
if !ids[i].Valid {
return nil, database.ErrMissingEntities
}
fMap[ids[i].Int64] = f
}
cacheTable := []vulnerabilityCache{}
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
if err != nil {
return nil, util.HandleError("searchPotentialAffectingVulneraibilities", err)
}
defer rows.Close()
for rows.Next() {
var (
cache vulnerabilityCache
affected string
)
err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
if err != nil {
return nil, err
}
if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
return nil, err
} else if ok {
cacheTable = append(cacheTable, cache)
}
}
return cacheTable, nil
}

View File

@ -0,0 +1,142 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vulnerability
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/util"
"github.com/lib/pq"
log "github.com/sirupsen/logrus"
)
const (
searchNamespacedFeaturesVulnerabilities = `
SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
FROM vulnerability_affected_namespaced_feature AS vanf,
Vulnerability AS v,
vulnerability_affected_feature AS vaf,
namespace AS n
WHERE vanf.namespaced_feature_id = ANY($1)
AND vaf.id = vanf.added_by
AND v.id = vanf.vulnerability_id
AND n.id = v.namespace_id
AND v.deleted_at IS NULL`
lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
insertVulnerabilityAffectedNamespacedFeature = `
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
VALUES ($1, $2, $3)`
)
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
return util.QueryPersist(count, "vulnerability_affected_namespaced_feature",
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
"vulnerability_id",
"namespaced_feature_id",
"added_by")
}
// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
// feature.
func FindAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
if len(features) == 0 {
return nil, nil
}
vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
featureIDs, err := feature.FindNamespacedFeatureIDs(tx, features)
if err != nil {
return nil, err
}
for i, id := range featureIDs {
if id.Valid {
vulnerableFeatures[i].Valid = true
vulnerableFeatures[i].NamespacedFeature = features[i]
}
}
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
if err != nil {
return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err)
}
defer rows.Close()
for rows.Next() {
var (
featureID int64
vuln database.VulnerabilityWithFixedIn
)
err := rows.Scan(&featureID,
&vuln.Name,
&vuln.Description,
&vuln.Link,
&vuln.Severity,
&vuln.Metadata,
&vuln.FixedInVersion,
&vuln.Namespace.Name,
&vuln.Namespace.VersionFormat,
)
if err != nil {
return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err)
}
for i, id := range featureIDs {
if id.Valid && id.Int64 == featureID {
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
}
}
}
return vulnerableFeatures, nil
}
func CacheAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
_, err := tx.Exec(lockVulnerabilityAffects)
if err != nil {
return util.HandleError("lockVulnerabilityAffects", err)
}
cache, err := SearchAffectingVulnerabilities(tx, features)
keys := make([]interface{}, 0, len(cache)*3)
for _, c := range cache {
keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID)
}
if len(cache) == 0 {
return nil
}
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
if err != nil {
return util.HandleError("persistVulnerabilityAffectedNamespacedFeature", err)
}
if count, err := affected.RowsAffected(); err != nil {
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
}
return nil
}

View File

@ -12,19 +12,30 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package pgsql package vulnerability
import ( import (
"database/sql"
"math/rand"
"strconv"
"testing" "testing"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/feature"
"github.com/coreos/clair/database/pgsql/namespace"
"github.com/coreos/clair/database/pgsql/testutil"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/strutil"
) )
func TestInsertVulnerabilities(t *testing.T) { func TestInsertVulnerabilities(t *testing.T) {
store, tx := openSessionForTest(t, "InsertVulnerabilities", true) store, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilities")
defer cleanup()
ns1 := database.Namespace{ ns1 := database.Namespace{
Name: "name", Name: "name",
@ -56,45 +67,48 @@ func TestInsertVulnerabilities(t *testing.T) {
Vulnerability: v2, Vulnerability: v2,
} }
tx, err := store.Begin()
require.Nil(t, err)
// empty // empty
err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{})
assert.Nil(t, err) assert.Nil(t, err)
// invalid content: vwa1 is invalid // invalid content: vwa1 is invalid
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa1, vwa2})
assert.NotNil(t, err) assert.NotNil(t, err)
tx = restartSession(t, store, tx, false) tx = testutil.RestartTransaction(store, tx, false)
// invalid content: duplicated input // invalid content: duplicated input
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2, vwa2})
assert.NotNil(t, err) assert.NotNil(t, err)
tx = restartSession(t, store, tx, false) tx = testutil.RestartTransaction(store, tx, false)
// valid content // valid content
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2})
assert.Nil(t, err) assert.Nil(t, err)
tx = restartSession(t, store, tx, true) tx = testutil.RestartTransaction(store, tx, true)
// ensure the content is in database // ensure the content is in database
vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) vulns, err := FindVulnerabilities(tx, []database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}})
if assert.Nil(t, err) && assert.Len(t, vulns, 1) { if assert.Nil(t, err) && assert.Len(t, vulns, 1) {
assert.True(t, vulns[0].Valid) assert.True(t, vulns[0].Valid)
} }
tx = restartSession(t, store, tx, false) tx = testutil.RestartTransaction(store, tx, false)
// valid content: vwa2 removed and inserted // valid content: vwa2 removed and inserted
err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) err = DeleteVulnerabilities(tx, []database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}})
assert.Nil(t, err) assert.Nil(t, err)
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2})
assert.Nil(t, err) assert.Nil(t, err)
closeTest(t, store, tx) require.Nil(t, tx.Rollback())
} }
func TestCachingVulnerable(t *testing.T) { func TestCachingVulnerable(t *testing.T) {
datastore, tx := openSessionForTest(t, "CachingVulnerable", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "CachingVulnerable")
defer closeTest(t, datastore, tx) defer cleanup()
ns := database.Namespace{ ns := database.Namespace{
Name: "debian:8", Name: "debian:8",
@ -163,11 +177,8 @@ func TestCachingVulnerable(t *testing.T) {
FixedInVersion: "2.2", FixedInVersion: "2.2",
} }
if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { require.Nil(t, InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vuln, vuln2}))
t.FailNow() r, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{f})
}
r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, r, 1) assert.Len(t, r, 1)
for _, anf := range r { for _, anf := range r {
@ -186,10 +197,10 @@ func TestCachingVulnerable(t *testing.T) {
} }
func TestFindVulnerabilities(t *testing.T) { func TestFindVulnerabilities(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindVulnerabilities", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilities")
defer closeTest(t, datastore, tx) defer cleanup()
vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ vuln, err := FindVulnerabilities(tx, []database.VulnerabilityID{
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
{Name: "CVE-NOPE", Namespace: "debian:7"}, {Name: "CVE-NOPE", Namespace: "debian:7"},
{Name: "CVE-NOT HERE"}, {Name: "CVE-NOT HERE"},
@ -255,7 +266,7 @@ func TestFindVulnerabilities(t *testing.T) {
expected, ok := expectedExistingMap[key] expected, ok := expectedExistingMap[key]
if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) {
assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) testutil.AssertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
} }
} else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) {
t.FailNow() t.FailNow()
@ -264,7 +275,7 @@ func TestFindVulnerabilities(t *testing.T) {
} }
// same vulnerability // same vulnerability
r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ r, err := FindVulnerabilities(tx, []database.VulnerabilityID{
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
}) })
@ -273,22 +284,22 @@ func TestFindVulnerabilities(t *testing.T) {
for _, vuln := range r { for _, vuln := range r {
if assert.True(t, vuln.Valid) { if assert.True(t, vuln.Valid) {
expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}]
assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) testutil.AssertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected)
} }
} }
} }
} }
func TestDeleteVulnerabilities(t *testing.T) { func TestDeleteVulnerabilities(t *testing.T) {
datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteVulnerabilities")
defer closeTest(t, datastore, tx) defer cleanup()
remove := []database.VulnerabilityID{} remove := []database.VulnerabilityID{}
// empty case // empty case
assert.Nil(t, tx.DeleteVulnerabilities(remove)) assert.Nil(t, DeleteVulnerabilities(tx, remove))
// invalid case // invalid case
remove = append(remove, database.VulnerabilityID{}) remove = append(remove, database.VulnerabilityID{})
assert.NotNil(t, tx.DeleteVulnerabilities(remove)) assert.NotNil(t, DeleteVulnerabilities(tx, remove))
// valid case // valid case
validRemove := []database.VulnerabilityID{ validRemove := []database.VulnerabilityID{
@ -296,8 +307,8 @@ func TestDeleteVulnerabilities(t *testing.T) {
{Name: "CVE-NOPE", Namespace: "debian:7"}, {Name: "CVE-NOPE", Namespace: "debian:7"},
} }
assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) assert.Nil(t, DeleteVulnerabilities(tx, validRemove))
vuln, err := tx.FindVulnerabilities(validRemove) vuln, err := FindVulnerabilities(tx, validRemove)
if assert.Nil(t, err) { if assert.Nil(t, err) {
for _, v := range vuln { for _, v := range vuln {
assert.False(t, v.Valid) assert.False(t, v.Valid)
@ -306,20 +317,158 @@ func TestDeleteVulnerabilities(t *testing.T) {
} }
func TestFindVulnerabilityIDs(t *testing.T) { func TestFindVulnerabilityIDs(t *testing.T) {
store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true) tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilityIDs")
defer closeTest(t, store, tx) defer cleanup()
ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) ids, err := FindLatestDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
if assert.Nil(t, err) { if assert.Nil(t, err) {
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) {
assert.Fail(t, "") assert.Fail(t, "")
} }
} }
ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) ids, err = FindNotDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
if assert.Nil(t, err) { if assert.Nil(t, err) {
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) {
assert.Fail(t, "") assert.Fail(t, "")
} }
} }
} }
func TestFindAffectedNamespacedFeatures(t *testing.T) {
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindAffectedNamespacedFeatures")
defer cleanup()
ns := database.NamespacedFeature{
Feature: database.Feature{
Name: "openssl",
Version: "1.0",
VersionFormat: "dpkg",
Type: database.SourcePackage,
},
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
},
}
ans, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{ns})
if assert.Nil(t, err) &&
assert.Len(t, ans, 1) &&
assert.True(t, ans[0].Valid) &&
assert.Len(t, ans[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
}
}
func genRandomVulnerabilityAndNamespacedFeature(t *testing.T, store *sql.DB) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
tx, err := store.Begin()
if err != nil {
panic(err)
}
numFeatures := 100
numVulnerabilities := 100
featureName := "TestFeature"
featureVersionFormat := dpkg.ParserName
// Insert the namespace on which we'll work.
ns := database.Namespace{
Name: "TestRaceAffectsFeatureNamespace1",
VersionFormat: dpkg.ParserName,
}
if !assert.Nil(t, namespace.PersistNamespaces(tx, []database.Namespace{ns})) {
t.FailNow()
}
// Generate Distinct random features
features := make([]database.Feature, numFeatures)
nsFeatures := make([]database.NamespacedFeature, numFeatures)
for i := 0; i < numFeatures; i++ {
version := rand.Intn(numFeatures)
features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
nsFeatures[i] = database.NamespacedFeature{
Namespace: ns,
Feature: features[i],
}
}
if !assert.Nil(t, feature.PersistFeatures(tx, features)) {
t.FailNow()
}
// Generate vulnerabilities.
vulnerabilities := []database.VulnerabilityWithAffected{}
for i := 0; i < numVulnerabilities; i++ {
// any version less than this is vulnerable
version := rand.Intn(numFeatures) + 1
vulnerability := database.VulnerabilityWithAffected{
Vulnerability: database.Vulnerability{
Name: uuid.New(),
Namespace: ns,
Severity: database.UnknownSeverity,
},
Affected: []database.AffectedFeature{
{
Namespace: ns,
FeatureName: featureName,
FeatureType: database.SourcePackage,
AffectedVersion: strconv.Itoa(version),
FixedInVersion: strconv.Itoa(version),
},
},
}
vulnerabilities = append(vulnerabilities, vulnerability)
}
tx.Commit()
return nsFeatures, vulnerabilities
}
func TestVulnChangeAffectsVulnerableFeatures(t *testing.T) {
db, cleanup := testutil.CreateTestDB(t, "caching")
defer cleanup()
nsFeatures, vulnerabilities := genRandomVulnerabilityAndNamespacedFeature(t, db)
tx, err := db.Begin()
require.Nil(t, err)
require.Nil(t, feature.PersistNamespacedFeatures(tx, nsFeatures))
require.Nil(t, tx.Commit())
tx, err = db.Begin()
require.Nil(t, InsertVulnerabilities(tx, vulnerabilities))
require.Nil(t, tx.Commit())
tx, err = db.Begin()
require.Nil(t, err)
defer tx.Rollback()
affected, err := FindAffectedNamespacedFeatures(tx, nsFeatures)
require.Nil(t, err)
for _, ansf := range affected {
require.True(t, ansf.Valid)
expectedAffectedNames := []string{}
for _, vuln := range vulnerabilities {
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
if ok {
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
}
}
}
actualAffectedNames := []string{}
for _, s := range ansf.AffectedBy {
actualAffectedNames = append(actualAffectedNames, s.Name)
}
require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
}
}

50
database/vulnerability.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
// VulnerabilityID is an identifier for every vulnerability. Every vulnerability
// has unique namespace and name.
type VulnerabilityID struct {
Name string
Namespace string
}
// Vulnerability represents CVE or similar vulnerability reports.
type Vulnerability struct {
Name string
Namespace Namespace
Description string
Link string
Severity Severity
Metadata MetadataMap
}
// VulnerabilityWithAffected is a vulnerability with all known affected
// features.
type VulnerabilityWithAffected struct {
Vulnerability
Affected []AffectedFeature
}
// NullableVulnerability is a vulnerability with whether the vulnerability is
// found in datastore.
type NullableVulnerability struct {
VulnerabilityWithAffected
Valid bool
}