Merge pull request #721 from KeyboardNerd/cache
Restructure database folder
This commit is contained in:
commit
2c7838eac7
96
database/ancestry.go
Normal file
96
database/ancestry.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
// Ancestry is a manifest that keeps all layers in an image in order.
|
||||
type Ancestry struct {
|
||||
// Name is a globally unique value for a set of layers. This is often the
|
||||
// sha256 digest of an OCI/Docker manifest.
|
||||
Name string `json:"name"`
|
||||
// By contains the processors that are used when computing the
|
||||
// content of this ancestry.
|
||||
By []Detector `json:"by"`
|
||||
// Layers should be ordered and i_th layer is the parent of i+1_th layer in
|
||||
// the slice.
|
||||
Layers []AncestryLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// Valid checks if the ancestry is compliant to spec.
|
||||
func (a *Ancestry) Valid() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if a.Name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, d := range a.By {
|
||||
if !d.Valid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, l := range a.Layers {
|
||||
if !l.Valid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AncestryLayer is a layer with all detected namespaced features.
|
||||
type AncestryLayer struct {
|
||||
// Hash is the sha-256 tarsum on the layer's blob content.
|
||||
Hash string `json:"hash"`
|
||||
// Features are the features introduced by this layer when it was
|
||||
// processed.
|
||||
Features []AncestryFeature `json:"features"`
|
||||
}
|
||||
|
||||
// Valid checks if the Ancestry Layer is compliant to the spec.
|
||||
func (l *AncestryLayer) Valid() bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Hash == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFeatures returns the Ancestry's features.
|
||||
func (l *AncestryLayer) GetFeatures() []NamespacedFeature {
|
||||
nsf := make([]NamespacedFeature, 0, len(l.Features))
|
||||
for _, f := range l.Features {
|
||||
nsf = append(nsf, f.NamespacedFeature)
|
||||
}
|
||||
|
||||
return nsf
|
||||
}
|
||||
|
||||
// AncestryFeature is a namespaced feature with the detectors used to
|
||||
// find this feature.
|
||||
type AncestryFeature struct {
|
||||
NamespacedFeature `json:"namespacedFeature"`
|
||||
|
||||
// FeatureBy is the detector that detected the feature.
|
||||
FeatureBy Detector `json:"featureBy"`
|
||||
// NamespaceBy is the detector that detected the namespace.
|
||||
NamespaceBy Detector `json:"namespaceBy"`
|
||||
}
|
96
database/feature.go
Normal file
96
database/feature.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
// Feature represents a package detected in a layer but the namespace is not
|
||||
// determined.
|
||||
//
|
||||
// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary
|
||||
// dpkg is the version format of the installer package manager, which in this
|
||||
// case could be dpkg or apk.
|
||||
type Feature struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
VersionFormat string `json:"versionFormat"`
|
||||
Type FeatureType `json:"type"`
|
||||
}
|
||||
|
||||
// NamespacedFeature is a feature with determined namespace and can be affected
|
||||
// by vulnerabilities.
|
||||
//
|
||||
// e.g. OpenSSL 1.0 dpkg Debian:7.
|
||||
type NamespacedFeature struct {
|
||||
Feature `json:"feature"`
|
||||
|
||||
Namespace Namespace `json:"namespace"`
|
||||
}
|
||||
|
||||
// AffectedNamespacedFeature is a namespaced feature affected by the
|
||||
// vulnerabilities with fixed-in versions for this feature.
|
||||
type AffectedNamespacedFeature struct {
|
||||
NamespacedFeature
|
||||
|
||||
AffectedBy []VulnerabilityWithFixedIn
|
||||
}
|
||||
|
||||
// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve
|
||||
// the affecting vulnerabilities and the fixed-in versions for the feature.
|
||||
type VulnerabilityWithFixedIn struct {
|
||||
Vulnerability
|
||||
|
||||
FixedInVersion string
|
||||
}
|
||||
|
||||
// AffectedFeature is used to determine whether a namespaced feature is affected
|
||||
// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is
|
||||
// bound to vulnerability.
|
||||
type AffectedFeature struct {
|
||||
// FeatureType determines which type of package it affects.
|
||||
FeatureType FeatureType
|
||||
Namespace Namespace
|
||||
FeatureName string
|
||||
// FixedInVersion is known next feature version that's not affected by the
|
||||
// vulnerability. Empty FixedInVersion means the unaffected version is
|
||||
// unknown.
|
||||
FixedInVersion string
|
||||
// AffectedVersion contains the version range to determine whether or not a
|
||||
// feature is affected.
|
||||
AffectedVersion string
|
||||
}
|
||||
|
||||
// NullableAffectedNamespacedFeature is an affectednamespacedfeature with
|
||||
// whether it's found in datastore.
|
||||
type NullableAffectedNamespacedFeature struct {
|
||||
AffectedNamespacedFeature
|
||||
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature {
|
||||
return &Feature{name, version, versionFormat, featureType}
|
||||
}
|
||||
|
||||
func NewBinaryPackage(name string, version string, versionFormat string) *Feature {
|
||||
return &Feature{name, version, versionFormat, BinaryPackage}
|
||||
}
|
||||
|
||||
func NewSourcePackage(name string, version string, versionFormat string) *Feature {
|
||||
return &Feature{name, version, versionFormat, SourcePackage}
|
||||
}
|
||||
|
||||
func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature {
|
||||
// TODO: namespaced feature should use pointer values
|
||||
return &NamespacedFeature{*feature, *namespace}
|
||||
}
|
65
database/layer.go
Normal file
65
database/layer.go
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
// Layer is a layer with all the detected features and namespaces.
|
||||
type Layer struct {
|
||||
// Hash is the sha-256 tarsum on the layer's blob content.
|
||||
Hash string `json:"hash"`
|
||||
// By contains a list of detectors scanned this Layer.
|
||||
By []Detector `json:"by"`
|
||||
Namespaces []LayerNamespace `json:"namespaces"`
|
||||
Features []LayerFeature `json:"features"`
|
||||
}
|
||||
|
||||
func (l *Layer) GetFeatures() []Feature {
|
||||
features := make([]Feature, 0, len(l.Features))
|
||||
for _, f := range l.Features {
|
||||
features = append(features, f.Feature)
|
||||
}
|
||||
|
||||
return features
|
||||
}
|
||||
|
||||
func (l *Layer) GetNamespaces() []Namespace {
|
||||
namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features))
|
||||
for _, ns := range l.Namespaces {
|
||||
namespaces = append(namespaces, ns.Namespace)
|
||||
}
|
||||
for _, f := range l.Features {
|
||||
if f.PotentialNamespace.Valid() {
|
||||
namespaces = append(namespaces, f.PotentialNamespace)
|
||||
}
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
||||
// LayerNamespace is a namespace with detection information.
|
||||
type LayerNamespace struct {
|
||||
Namespace `json:"namespace"`
|
||||
|
||||
// By is the detector found the namespace.
|
||||
By Detector `json:"by"`
|
||||
}
|
||||
|
||||
// LayerFeature is a feature with detection information.
|
||||
type LayerFeature struct {
|
||||
Feature `json:"feature"`
|
||||
|
||||
// By is the detector found the feature.
|
||||
By Detector `json:"by"`
|
||||
PotentialNamespace Namespace `json:"potentialNamespace"`
|
||||
}
|
41
database/metadata.go
Normal file
41
database/metadata.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MetadataMap is for storing the metadata returned by vulnerability database.
|
||||
type MetadataMap map[string]interface{}
|
||||
|
||||
func (mm *MetadataMap) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// github.com/lib/pq decodes TEXT/VARCHAR fields into strings.
|
||||
val, ok := value.(string)
|
||||
if !ok {
|
||||
panic("got type other than []byte from database")
|
||||
}
|
||||
return json.Unmarshal([]byte(val), mm)
|
||||
}
|
||||
|
||||
func (mm *MetadataMap) Value() (driver.Value, error) {
|
||||
json, err := json.Marshal(*mm)
|
||||
return string(json), err
|
||||
}
|
@ -1,363 +0,0 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
// Ancestry is a manifest that keeps all layers in an image in order.
|
||||
type Ancestry struct {
|
||||
// Name is a globally unique value for a set of layers. This is often the
|
||||
// sha256 digest of an OCI/Docker manifest.
|
||||
Name string `json:"name"`
|
||||
// By contains the processors that are used when computing the
|
||||
// content of this ancestry.
|
||||
By []Detector `json:"by"`
|
||||
// Layers should be ordered and i_th layer is the parent of i+1_th layer in
|
||||
// the slice.
|
||||
Layers []AncestryLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// Valid checks if the ancestry is compliant to spec.
|
||||
func (a *Ancestry) Valid() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if a.Name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, d := range a.By {
|
||||
if !d.Valid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, l := range a.Layers {
|
||||
if !l.Valid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AncestryLayer is a layer with all detected namespaced features.
|
||||
type AncestryLayer struct {
|
||||
// Hash is the sha-256 tarsum on the layer's blob content.
|
||||
Hash string `json:"hash"`
|
||||
// Features are the features introduced by this layer when it was
|
||||
// processed.
|
||||
Features []AncestryFeature `json:"features"`
|
||||
}
|
||||
|
||||
// Valid checks if the Ancestry Layer is compliant to the spec.
|
||||
func (l *AncestryLayer) Valid() bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Hash == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFeatures returns the Ancestry's features.
|
||||
func (l *AncestryLayer) GetFeatures() []NamespacedFeature {
|
||||
nsf := make([]NamespacedFeature, 0, len(l.Features))
|
||||
for _, f := range l.Features {
|
||||
nsf = append(nsf, f.NamespacedFeature)
|
||||
}
|
||||
|
||||
return nsf
|
||||
}
|
||||
|
||||
// AncestryFeature is a namespaced feature with the detectors used to
|
||||
// find this feature.
|
||||
type AncestryFeature struct {
|
||||
NamespacedFeature `json:"namespacedFeature"`
|
||||
|
||||
// FeatureBy is the detector that detected the feature.
|
||||
FeatureBy Detector `json:"featureBy"`
|
||||
// NamespaceBy is the detector that detected the namespace.
|
||||
NamespaceBy Detector `json:"namespaceBy"`
|
||||
}
|
||||
|
||||
// Layer is a layer with all the detected features and namespaces.
|
||||
type Layer struct {
|
||||
// Hash is the sha-256 tarsum on the layer's blob content.
|
||||
Hash string `json:"hash"`
|
||||
// By contains a list of detectors scanned this Layer.
|
||||
By []Detector `json:"by"`
|
||||
Namespaces []LayerNamespace `json:"namespaces"`
|
||||
Features []LayerFeature `json:"features"`
|
||||
}
|
||||
|
||||
func (l *Layer) GetFeatures() []Feature {
|
||||
features := make([]Feature, 0, len(l.Features))
|
||||
for _, f := range l.Features {
|
||||
features = append(features, f.Feature)
|
||||
}
|
||||
|
||||
return features
|
||||
}
|
||||
|
||||
func (l *Layer) GetNamespaces() []Namespace {
|
||||
namespaces := make([]Namespace, 0, len(l.Namespaces)+len(l.Features))
|
||||
for _, ns := range l.Namespaces {
|
||||
namespaces = append(namespaces, ns.Namespace)
|
||||
}
|
||||
for _, f := range l.Features {
|
||||
if f.PotentialNamespace.Valid() {
|
||||
namespaces = append(namespaces, f.PotentialNamespace)
|
||||
}
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
||||
// LayerNamespace is a namespace with detection information.
|
||||
type LayerNamespace struct {
|
||||
Namespace `json:"namespace"`
|
||||
|
||||
// By is the detector found the namespace.
|
||||
By Detector `json:"by"`
|
||||
}
|
||||
|
||||
// LayerFeature is a feature with detection information.
|
||||
type LayerFeature struct {
|
||||
Feature `json:"feature"`
|
||||
|
||||
// By is the detector found the feature.
|
||||
By Detector `json:"by"`
|
||||
PotentialNamespace Namespace `json:"potentialNamespace"`
|
||||
}
|
||||
|
||||
// Namespace is the contextual information around features.
|
||||
//
|
||||
// e.g. Debian:7, NodeJS.
|
||||
type Namespace struct {
|
||||
Name string `json:"name"`
|
||||
VersionFormat string `json:"versionFormat"`
|
||||
}
|
||||
|
||||
func NewNamespace(name string, versionFormat string) *Namespace {
|
||||
return &Namespace{name, versionFormat}
|
||||
}
|
||||
|
||||
func (ns *Namespace) Valid() bool {
|
||||
if ns.Name == "" || ns.VersionFormat == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Feature represents a package detected in a layer but the namespace is not
|
||||
// determined.
|
||||
//
|
||||
// e.g. Name: Libssl1.0, Version: 1.0, VersionFormat: dpkg, Type: binary
|
||||
// dpkg is the version format of the installer package manager, which in this
|
||||
// case could be dpkg or apk.
|
||||
type Feature struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
VersionFormat string `json:"versionFormat"`
|
||||
Type FeatureType `json:"type"`
|
||||
}
|
||||
|
||||
func NewFeature(name string, version string, versionFormat string, featureType FeatureType) *Feature {
|
||||
return &Feature{name, version, versionFormat, featureType}
|
||||
}
|
||||
|
||||
func NewBinaryPackage(name string, version string, versionFormat string) *Feature {
|
||||
return &Feature{name, version, versionFormat, BinaryPackage}
|
||||
}
|
||||
|
||||
func NewSourcePackage(name string, version string, versionFormat string) *Feature {
|
||||
return &Feature{name, version, versionFormat, SourcePackage}
|
||||
}
|
||||
|
||||
// NamespacedFeature is a feature with determined namespace and can be affected
|
||||
// by vulnerabilities.
|
||||
//
|
||||
// e.g. OpenSSL 1.0 dpkg Debian:7.
|
||||
type NamespacedFeature struct {
|
||||
Feature `json:"feature"`
|
||||
|
||||
Namespace Namespace `json:"namespace"`
|
||||
}
|
||||
|
||||
func (nf *NamespacedFeature) Key() string {
|
||||
return fmt.Sprintf("%s-%s-%s-%s-%s-%s", nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name, nf.Namespace.VersionFormat)
|
||||
}
|
||||
|
||||
func NewNamespacedFeature(namespace *Namespace, feature *Feature) *NamespacedFeature {
|
||||
// TODO: namespaced feature should use pointer values
|
||||
return &NamespacedFeature{*feature, *namespace}
|
||||
}
|
||||
|
||||
// AffectedNamespacedFeature is a namespaced feature affected by the
|
||||
// vulnerabilities with fixed-in versions for this feature.
|
||||
type AffectedNamespacedFeature struct {
|
||||
NamespacedFeature
|
||||
|
||||
AffectedBy []VulnerabilityWithFixedIn
|
||||
}
|
||||
|
||||
// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve
|
||||
// the affecting vulnerabilities and the fixed-in versions for the feature.
|
||||
type VulnerabilityWithFixedIn struct {
|
||||
Vulnerability
|
||||
|
||||
FixedInVersion string
|
||||
}
|
||||
|
||||
// AffectedFeature is used to determine whether a namespaced feature is affected
|
||||
// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is
|
||||
// bound to vulnerability.
|
||||
type AffectedFeature struct {
|
||||
// FeatureType determines which type of package it affects.
|
||||
FeatureType FeatureType
|
||||
Namespace Namespace
|
||||
FeatureName string
|
||||
// FixedInVersion is known next feature version that's not affected by the
|
||||
// vulnerability. Empty FixedInVersion means the unaffected version is
|
||||
// unknown.
|
||||
FixedInVersion string
|
||||
// AffectedVersion contains the version range to determine whether or not a
|
||||
// feature is affected.
|
||||
AffectedVersion string
|
||||
}
|
||||
|
||||
// VulnerabilityID is an identifier for every vulnerability. Every vulnerability
|
||||
// has unique namespace and name.
|
||||
type VulnerabilityID struct {
|
||||
Name string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
// Vulnerability represents CVE or similar vulnerability reports.
|
||||
type Vulnerability struct {
|
||||
Name string
|
||||
Namespace Namespace
|
||||
|
||||
Description string
|
||||
Link string
|
||||
Severity Severity
|
||||
|
||||
Metadata MetadataMap
|
||||
}
|
||||
|
||||
// VulnerabilityWithAffected is a vulnerability with all known affected
|
||||
// features.
|
||||
type VulnerabilityWithAffected struct {
|
||||
Vulnerability
|
||||
|
||||
Affected []AffectedFeature
|
||||
}
|
||||
|
||||
// PagedVulnerableAncestries is a vulnerability with a page of affected
|
||||
// ancestries each with a special index attached for streaming purpose. The
|
||||
// current page number and next page number are for navigate.
|
||||
type PagedVulnerableAncestries struct {
|
||||
Vulnerability
|
||||
|
||||
// Affected is a map of special indexes to Ancestries, which the pair
|
||||
// should be unique in a stream. Every indexes in the map should be larger
|
||||
// than previous page.
|
||||
Affected map[int]string
|
||||
|
||||
Limit int
|
||||
Current pagination.Token
|
||||
Next pagination.Token
|
||||
|
||||
// End signals the end of the pages.
|
||||
End bool
|
||||
}
|
||||
|
||||
// NotificationHook is a message sent to another service to inform of a change
|
||||
// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains
|
||||
// the name of a notification that should be read and marked as read via the
|
||||
// API.
|
||||
type NotificationHook struct {
|
||||
Name string
|
||||
|
||||
Created time.Time
|
||||
Notified time.Time
|
||||
Deleted time.Time
|
||||
}
|
||||
|
||||
// VulnerabilityNotification is a notification for vulnerability changes.
|
||||
type VulnerabilityNotification struct {
|
||||
NotificationHook
|
||||
|
||||
Old *Vulnerability
|
||||
New *Vulnerability
|
||||
}
|
||||
|
||||
// VulnerabilityNotificationWithVulnerable is a notification for vulnerability
|
||||
// changes with vulnerable ancestries.
|
||||
type VulnerabilityNotificationWithVulnerable struct {
|
||||
NotificationHook
|
||||
|
||||
Old *PagedVulnerableAncestries
|
||||
New *PagedVulnerableAncestries
|
||||
}
|
||||
|
||||
// MetadataMap is for storing the metadata returned by vulnerability database.
|
||||
type MetadataMap map[string]interface{}
|
||||
|
||||
// NullableAffectedNamespacedFeature is an affectednamespacedfeature with
|
||||
// whether it's found in datastore.
|
||||
type NullableAffectedNamespacedFeature struct {
|
||||
AffectedNamespacedFeature
|
||||
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// NullableVulnerability is a vulnerability with whether the vulnerability is
|
||||
// found in datastore.
|
||||
type NullableVulnerability struct {
|
||||
VulnerabilityWithAffected
|
||||
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (mm *MetadataMap) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// github.com/lib/pq decodes TEXT/VARCHAR fields into strings.
|
||||
val, ok := value.(string)
|
||||
if !ok {
|
||||
panic("got type other than []byte from database")
|
||||
}
|
||||
return json.Unmarshal([]byte(val), mm)
|
||||
}
|
||||
|
||||
func (mm *MetadataMap) Value() (driver.Value, error) {
|
||||
json, err := json.Marshal(*mm)
|
||||
return string(json), err
|
||||
}
|
34
database/namespace.go
Normal file
34
database/namespace.go
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
// Namespace is the contextual information around features.
|
||||
//
|
||||
// e.g. Debian:7, NodeJS.
|
||||
type Namespace struct {
|
||||
Name string `json:"name"`
|
||||
VersionFormat string `json:"versionFormat"`
|
||||
}
|
||||
|
||||
func NewNamespace(name string, versionFormat string) *Namespace {
|
||||
return &Namespace{name, versionFormat}
|
||||
}
|
||||
|
||||
func (ns *Namespace) Valid() bool {
|
||||
if ns.Name == "" || ns.VersionFormat == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
69
database/notification.go
Normal file
69
database/notification.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
// NotificationHook is a message sent to another service to inform of a change
|
||||
// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains
|
||||
// the name of a notification that should be read and marked as read via the
|
||||
// API.
|
||||
type NotificationHook struct {
|
||||
Name string
|
||||
|
||||
Created time.Time
|
||||
Notified time.Time
|
||||
Deleted time.Time
|
||||
}
|
||||
|
||||
// VulnerabilityNotification is a notification for vulnerability changes.
|
||||
type VulnerabilityNotification struct {
|
||||
NotificationHook
|
||||
|
||||
Old *Vulnerability
|
||||
New *Vulnerability
|
||||
}
|
||||
|
||||
// VulnerabilityNotificationWithVulnerable is a notification for vulnerability
|
||||
// changes with vulnerable ancestries.
|
||||
type VulnerabilityNotificationWithVulnerable struct {
|
||||
NotificationHook
|
||||
|
||||
Old *PagedVulnerableAncestries
|
||||
New *PagedVulnerableAncestries
|
||||
}
|
||||
|
||||
// PagedVulnerableAncestries is a vulnerability with a page of affected
|
||||
// ancestries each with a special index attached for streaming purpose. The
|
||||
// current page number and next page number are for navigate.
|
||||
type PagedVulnerableAncestries struct {
|
||||
Vulnerability
|
||||
|
||||
// Affected is a map of special indexes to Ancestries, which the pair
|
||||
// should be unique in a stream. Every indexes in the map should be larger
|
||||
// than previous page.
|
||||
Affected map[int]string
|
||||
|
||||
Limit int
|
||||
Current pagination.Token
|
||||
Next pagination.Token
|
||||
|
||||
// End signals the end of the pages.
|
||||
End bool
|
||||
}
|
@ -1,375 +0,0 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
const (
|
||||
insertAncestry = `
|
||||
INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
|
||||
|
||||
findAncestryLayerHashes = `
|
||||
SELECT layer.hash, ancestry_layer.ancestry_index
|
||||
FROM layer, ancestry_layer
|
||||
WHERE ancestry_layer.ancestry_id = $1
|
||||
AND ancestry_layer.layer_id = layer.id
|
||||
ORDER BY ancestry_layer.ancestry_index ASC`
|
||||
|
||||
findAncestryFeatures = `
|
||||
SELECT namespace.name, namespace.version_format, feature.name,
|
||||
feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
|
||||
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
|
||||
FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
|
||||
WHERE ancestry_layer.ancestry_id = $1
|
||||
AND feature_type.id = feature.type
|
||||
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
|
||||
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
|
||||
AND namespaced_feature.feature_id = feature.id
|
||||
AND namespaced_feature.namespace_id = namespace.id`
|
||||
|
||||
findAncestryID = `SELECT id FROM ancestry WHERE name = $1`
|
||||
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
|
||||
insertAncestryLayers = `
|
||||
INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3)
|
||||
RETURNING id`
|
||||
insertAncestryFeatures = `
|
||||
INSERT INTO ancestry_feature
|
||||
(ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
|
||||
($1, $2, $3, $4)`
|
||||
)
|
||||
|
||||
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) {
|
||||
var (
|
||||
ancestry = database.Ancestry{Name: name}
|
||||
err error
|
||||
)
|
||||
|
||||
id, ok, err := tx.findAncestryID(name)
|
||||
if !ok || err != nil {
|
||||
return ancestry, ok, err
|
||||
}
|
||||
|
||||
if ancestry.By, err = tx.findAncestryDetectors(id); err != nil {
|
||||
return ancestry, false, err
|
||||
}
|
||||
|
||||
if ancestry.Layers, err = tx.findAncestryLayers(id); err != nil {
|
||||
return ancestry, false, err
|
||||
}
|
||||
|
||||
return ancestry, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error {
|
||||
if !ancestry.Valid() {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
|
||||
if err := tx.removeAncestry(ancestry.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := tx.insertAncestry(ancestry.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
detectorIDs, err := tx.findDetectorIDs(ancestry.By)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// insert ancestry metadata
|
||||
if err := tx.insertAncestryDetectors(id, detectorIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers := make([]string, 0, len(ancestry.Layers))
|
||||
for _, layer := range ancestry.Layers {
|
||||
layers = append(layers, layer.Hash)
|
||||
}
|
||||
|
||||
layerIDs, ok, err := tx.findLayerIDs(layers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.")
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
ancestryLayerIDs, err := tx.insertAncestryLayers(id, layerIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ancestryLayerIDs {
|
||||
if err := tx.insertAncestryFeatures(id, ancestry.Layers[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestry(name string) (int64, error) {
|
||||
var id int64
|
||||
err := tx.QueryRow(insertAncestry, name).Scan(&id)
|
||||
if err != nil {
|
||||
if isErrUniqueViolation(err) {
|
||||
return 0, handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
|
||||
}
|
||||
|
||||
return 0, handleError("insertAncestry", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryID(name string) (int64, bool, error) {
|
||||
var id sql.NullInt64
|
||||
if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return 0, false, handleError("findAncestryID", err)
|
||||
}
|
||||
|
||||
return id.Int64, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) removeAncestry(name string) error {
|
||||
result, err := tx.Exec(removeAncestry, name)
|
||||
if err != nil {
|
||||
return handleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
log.WithField("ancestry", name).Debug("removed ancestry")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) {
|
||||
detectors, err := tx.findAllDetectors()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layerMap, err := tx.findAncestryLayerHashes(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
featureMap, err := tx.findAncestryFeatures(id, detectors)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := make([]database.AncestryLayer, len(layerMap))
|
||||
for index, layer := range layerMap {
|
||||
// index MUST match the ancestry layer slice index.
|
||||
if layers[index].Hash == "" && len(layers[index].Features) == 0 {
|
||||
layers[index] = database.AncestryLayer{
|
||||
Hash: layer,
|
||||
Features: featureMap[index],
|
||||
}
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"ancestry ID": id,
|
||||
"duplicated ancestry index": index,
|
||||
}).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed")
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryLayerHashes(ancestryID int64) (map[int64]string, error) {
|
||||
// retrieve layer indexes and hashes
|
||||
rows, err := tx.Query(findAncestryLayerHashes, ancestryID)
|
||||
if err != nil {
|
||||
return nil, handleError("findAncestryLayerHashes", err)
|
||||
}
|
||||
|
||||
layerHashes := map[int64]string{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
hash string
|
||||
index int64
|
||||
)
|
||||
|
||||
if err = rows.Scan(&hash, &index); err != nil {
|
||||
return nil, handleError("findAncestryLayerHashes", err)
|
||||
}
|
||||
|
||||
if _, ok := layerHashes[index]; ok {
|
||||
// one ancestry index should correspond to only one layer
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
layerHashes[index] = hash
|
||||
}
|
||||
|
||||
return layerHashes, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMap) (map[int64][]database.AncestryFeature, error) {
|
||||
// ancestry_index -> ancestry features
|
||||
featureMap := make(map[int64][]database.AncestryFeature)
|
||||
// retrieve ancestry layer's namespaced features
|
||||
rows, err := tx.Query(findAncestryFeatures, ancestryID)
|
||||
if err != nil {
|
||||
return nil, handleError("findAncestryFeatures", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
featureDetectorID int64
|
||||
namespaceDetectorID int64
|
||||
feature database.NamespacedFeature
|
||||
// index is used to determine which layer the feature belongs to.
|
||||
index sql.NullInt64
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&feature.Namespace.Name,
|
||||
&feature.Namespace.VersionFormat,
|
||||
&feature.Feature.Name,
|
||||
&feature.Feature.Version,
|
||||
&feature.Feature.VersionFormat,
|
||||
&feature.Feature.Type,
|
||||
&index,
|
||||
&featureDetectorID,
|
||||
&namespaceDetectorID,
|
||||
); err != nil {
|
||||
return nil, handleError("findAncestryFeatures", err)
|
||||
}
|
||||
|
||||
if feature.Feature.VersionFormat != feature.Namespace.VersionFormat {
|
||||
// Feature must have the same version format as the associated
|
||||
// namespace version format.
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
fDetector, ok := detectors.byID[featureDetectorID]
|
||||
if !ok {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
nsDetector, ok := detectors.byID[namespaceDetectorID]
|
||||
if !ok {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{
|
||||
NamespacedFeature: feature,
|
||||
FeatureBy: fDetector,
|
||||
NamespaceBy: nsDetector,
|
||||
})
|
||||
}
|
||||
|
||||
return featureMap, nil
|
||||
}
|
||||
|
||||
// insertAncestryLayers inserts the ancestry layers along with its content into
|
||||
// the database. The layers are 0 based indexed in the original order.
|
||||
func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []int64) ([]int64, error) {
|
||||
stmt, err := tx.Prepare(insertAncestryLayers)
|
||||
if err != nil {
|
||||
return nil, handleError("insertAncestryLayers", err)
|
||||
}
|
||||
|
||||
ancestryLayerIDs := []int64{}
|
||||
for index, layerID := range layers {
|
||||
var ancestryLayerID sql.NullInt64
|
||||
if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil {
|
||||
return nil, handleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close()))
|
||||
}
|
||||
|
||||
if !ancestryLayerID.Valid {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64)
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return nil, handleError("insertAncestryLayers", err)
|
||||
}
|
||||
|
||||
return ancestryLayerIDs, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestryFeatures(ancestryLayerID int64, layer database.AncestryLayer) error {
|
||||
detectors, err := tx.findAllDetectors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(layer.GetFeatures())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// find the detectors for each feature
|
||||
stmt, err := tx.Prepare(insertAncestryFeatures)
|
||||
if err != nil {
|
||||
return handleError("insertAncestryFeatures", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
|
||||
for index, id := range nsFeatureIDs {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
namespaceDetectorID, ok := detectors.byValue[layer.Features[index].NamespaceBy]
|
||||
if !ok {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
featureDetectorID, ok := detectors.byValue[layer.Features[index].FeatureBy]
|
||||
if !ok {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil {
|
||||
return handleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close()))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
160
database/pgsql/ancestry/ancestry.go
Normal file
160
database/pgsql/ancestry/ancestry.go
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ancestry
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/layer"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
const (
|
||||
insertAncestry = `
|
||||
INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
|
||||
|
||||
findAncestryID = `SELECT id FROM ancestry WHERE name = $1`
|
||||
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
|
||||
|
||||
insertAncestryFeatures = `
|
||||
INSERT INTO ancestry_feature
|
||||
(ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
|
||||
($1, $2, $3, $4)`
|
||||
)
|
||||
|
||||
func FindAncestry(tx *sql.Tx, name string) (database.Ancestry, bool, error) {
|
||||
var (
|
||||
ancestry = database.Ancestry{Name: name}
|
||||
err error
|
||||
)
|
||||
|
||||
id, ok, err := FindAncestryID(tx, name)
|
||||
if !ok || err != nil {
|
||||
return ancestry, ok, err
|
||||
}
|
||||
|
||||
if ancestry.By, err = FindAncestryDetectors(tx, id); err != nil {
|
||||
return ancestry, false, err
|
||||
}
|
||||
|
||||
if ancestry.Layers, err = FindAncestryLayers(tx, id); err != nil {
|
||||
return ancestry, false, err
|
||||
}
|
||||
|
||||
return ancestry, true, nil
|
||||
}
|
||||
|
||||
func UpsertAncestry(tx *sql.Tx, ancestry database.Ancestry) error {
|
||||
if !ancestry.Valid() {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
|
||||
if err := RemoveAncestry(tx, ancestry.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := InsertAncestry(tx, ancestry.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
detectorIDs, err := detector.FindDetectorIDs(tx, ancestry.By)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// insert ancestry metadata
|
||||
if err := InsertAncestryDetectors(tx, id, detectorIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
layers := make([]string, 0, len(ancestry.Layers))
|
||||
for _, l := range ancestry.Layers {
|
||||
layers = append(layers, l.Hash)
|
||||
}
|
||||
|
||||
layerIDs, ok, err := layer.FindLayerIDs(tx, layers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.")
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
ancestryLayerIDs, err := InsertAncestryLayers(tx, id, layerIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ancestryLayerIDs {
|
||||
if err := InsertAncestryFeatures(tx, id, ancestry.Layers[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertAncestry(tx *sql.Tx, name string) (int64, error) {
|
||||
var id int64
|
||||
err := tx.QueryRow(insertAncestry, name).Scan(&id)
|
||||
if err != nil {
|
||||
if util.IsErrUniqueViolation(err) {
|
||||
return 0, util.HandleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
|
||||
}
|
||||
|
||||
return 0, util.HandleError("insertAncestry", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func FindAncestryID(tx *sql.Tx, name string) (int64, bool, error) {
|
||||
var id sql.NullInt64
|
||||
if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return 0, false, util.HandleError("findAncestryID", err)
|
||||
}
|
||||
|
||||
return id.Int64, true, nil
|
||||
}
|
||||
|
||||
func RemoveAncestry(tx *sql.Tx, name string) error {
|
||||
result, err := tx.Exec(removeAncestry, name)
|
||||
if err != nil {
|
||||
return util.HandleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return util.HandleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
log.WithField("ancestry", name).Debug("removed ancestry")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
48
database/pgsql/ancestry/ancestry_detector.go
Normal file
48
database/pgsql/ancestry/ancestry_detector.go
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ancestry
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
var selectAncestryDetectors = `
|
||||
SELECT d.name, d.version, d.dtype
|
||||
FROM ancestry_detector, detector AS d
|
||||
WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`
|
||||
|
||||
var insertAncestryDetectors = `
|
||||
INSERT INTO ancestry_detector (ancestry_id, detector_id)
|
||||
SELECT $1, $2
|
||||
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`
|
||||
|
||||
func FindAncestryDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) {
|
||||
detectors, err := detector.GetDetectors(tx, selectAncestryDetectors, id)
|
||||
return detectors, err
|
||||
}
|
||||
|
||||
func InsertAncestryDetectors(tx *sql.Tx, ancestryID int64, detectorIDs []int64) error {
|
||||
for _, detectorID := range detectorIDs {
|
||||
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
|
||||
return util.HandleError("insertAncestryDetectors", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
139
database/pgsql/ancestry/ancestry_feature.go
Normal file
139
database/pgsql/ancestry/ancestry_feature.go
Normal file
@ -0,0 +1,139 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ancestry
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
const findAncestryFeatures = `
|
||||
SELECT namespace.name, namespace.version_format, feature.name,
|
||||
feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
|
||||
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
|
||||
FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
|
||||
WHERE ancestry_layer.ancestry_id = $1
|
||||
AND feature_type.id = feature.type
|
||||
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
|
||||
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
|
||||
AND namespaced_feature.feature_id = feature.id
|
||||
AND namespaced_feature.namespace_id = namespace.id`
|
||||
|
||||
func FindAncestryFeatures(tx *sql.Tx, ancestryID int64, detectors detector.DetectorMap) (map[int64][]database.AncestryFeature, error) {
|
||||
// ancestry_index -> ancestry features
|
||||
featureMap := make(map[int64][]database.AncestryFeature)
|
||||
// retrieve ancestry layer's namespaced features
|
||||
rows, err := tx.Query(findAncestryFeatures, ancestryID)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("findAncestryFeatures", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
featureDetectorID int64
|
||||
namespaceDetectorID int64
|
||||
feature database.NamespacedFeature
|
||||
// index is used to determine which layer the feature belongs to.
|
||||
index sql.NullInt64
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&feature.Namespace.Name,
|
||||
&feature.Namespace.VersionFormat,
|
||||
&feature.Feature.Name,
|
||||
&feature.Feature.Version,
|
||||
&feature.Feature.VersionFormat,
|
||||
&feature.Feature.Type,
|
||||
&index,
|
||||
&featureDetectorID,
|
||||
&namespaceDetectorID,
|
||||
); err != nil {
|
||||
return nil, util.HandleError("findAncestryFeatures", err)
|
||||
}
|
||||
|
||||
if feature.Feature.VersionFormat != feature.Namespace.VersionFormat {
|
||||
// Feature must have the same version format as the associated
|
||||
// namespace version format.
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
fDetector, ok := detectors.ByID[featureDetectorID]
|
||||
if !ok {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
nsDetector, ok := detectors.ByID[namespaceDetectorID]
|
||||
if !ok {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{
|
||||
NamespacedFeature: feature,
|
||||
FeatureBy: fDetector,
|
||||
NamespaceBy: nsDetector,
|
||||
})
|
||||
}
|
||||
|
||||
return featureMap, nil
|
||||
}
|
||||
|
||||
func InsertAncestryFeatures(tx *sql.Tx, ancestryLayerID int64, layer database.AncestryLayer) error {
|
||||
detectors, err := detector.FindAllDetectors(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsFeatureIDs, err := feature.FindNamespacedFeatureIDs(tx, layer.GetFeatures())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// find the detectors for each feature
|
||||
stmt, err := tx.Prepare(insertAncestryFeatures)
|
||||
if err != nil {
|
||||
return util.HandleError("insertAncestryFeatures", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
|
||||
for index, id := range nsFeatureIDs {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
namespaceDetectorID, ok := detectors.ByValue[layer.Features[index].NamespaceBy]
|
||||
if !ok {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
featureDetectorID, ok := detectors.ByValue[layer.Features[index].FeatureBy]
|
||||
if !ok {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil {
|
||||
return util.HandleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close()))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
131
database/pgsql/ancestry/ancestry_layer.go
Normal file
131
database/pgsql/ancestry/ancestry_layer.go
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ancestry
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
findAncestryLayerHashes = `
|
||||
SELECT layer.hash, ancestry_layer.ancestry_index
|
||||
FROM layer, ancestry_layer
|
||||
WHERE ancestry_layer.ancestry_id = $1
|
||||
AND ancestry_layer.layer_id = layer.id
|
||||
ORDER BY ancestry_layer.ancestry_index ASC`
|
||||
insertAncestryLayers = `
|
||||
INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3)
|
||||
RETURNING id`
|
||||
)
|
||||
|
||||
func FindAncestryLayerHashes(tx *sql.Tx, ancestryID int64) (map[int64]string, error) {
|
||||
// retrieve layer indexes and hashes
|
||||
rows, err := tx.Query(findAncestryLayerHashes, ancestryID)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("findAncestryLayerHashes", err)
|
||||
}
|
||||
|
||||
layerHashes := map[int64]string{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
hash string
|
||||
index int64
|
||||
)
|
||||
|
||||
if err = rows.Scan(&hash, &index); err != nil {
|
||||
return nil, util.HandleError("findAncestryLayerHashes", err)
|
||||
}
|
||||
|
||||
if _, ok := layerHashes[index]; ok {
|
||||
// one ancestry index should correspond to only one layer
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
layerHashes[index] = hash
|
||||
}
|
||||
|
||||
return layerHashes, nil
|
||||
}
|
||||
|
||||
// insertAncestryLayers inserts the ancestry layers along with its content into
|
||||
// the database. The layers are 0 based indexed in the original order.
|
||||
func InsertAncestryLayers(tx *sql.Tx, ancestryID int64, layers []int64) ([]int64, error) {
|
||||
stmt, err := tx.Prepare(insertAncestryLayers)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("insertAncestryLayers", err)
|
||||
}
|
||||
|
||||
ancestryLayerIDs := []int64{}
|
||||
for index, layerID := range layers {
|
||||
var ancestryLayerID sql.NullInt64
|
||||
if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil {
|
||||
return nil, util.HandleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close()))
|
||||
}
|
||||
|
||||
if !ancestryLayerID.Valid {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64)
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return nil, util.HandleError("insertAncestryLayers", err)
|
||||
}
|
||||
|
||||
return ancestryLayerIDs, nil
|
||||
}
|
||||
|
||||
func FindAncestryLayers(tx *sql.Tx, id int64) ([]database.AncestryLayer, error) {
|
||||
detectors, err := detector.FindAllDetectors(tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layerMap, err := FindAncestryLayerHashes(tx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
featureMap, err := FindAncestryFeatures(tx, id, detectors)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := make([]database.AncestryLayer, len(layerMap))
|
||||
for index, layer := range layerMap {
|
||||
// index MUST match the ancestry layer slice index.
|
||||
if layers[index].Hash == "" && len(layers[index].Features) == 0 {
|
||||
layers[index] = database.AncestryLayer{
|
||||
Hash: layer,
|
||||
Features: featureMap[index],
|
||||
}
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"ancestry ID": id,
|
||||
"duplicated ancestry index": index,
|
||||
}).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed")
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2017 clair authors
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package ancestry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
var upsertAncestryTests = []struct {
|
||||
@ -55,9 +56,9 @@ var upsertAncestryTests = []struct {
|
||||
title: "ancestry with invalid feature",
|
||||
in: &database.Ancestry{
|
||||
Name: "a",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2]},
|
||||
By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
|
||||
Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{
|
||||
{fakeNamespacedFeatures[1], fakeDetector[1], fakeDetector[2]},
|
||||
{testutil.FakeNamespacedFeatures[1], testutil.FakeDetector[1], testutil.FakeDetector[2]},
|
||||
}}},
|
||||
},
|
||||
err: database.ErrMissingEntities.Error(),
|
||||
@ -66,26 +67,27 @@ var upsertAncestryTests = []struct {
|
||||
title: "replace old ancestry",
|
||||
in: &database.Ancestry{
|
||||
Name: "a",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2]},
|
||||
By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
|
||||
Layers: []database.AncestryLayer{
|
||||
{"layer-1", []database.AncestryFeature{{realNamespacedFeatures[1], realDetectors[2], realDetectors[1]}}},
|
||||
{"layer-1", []database.AncestryFeature{{testutil.RealNamespacedFeatures[1], testutil.RealDetectors[2], testutil.RealDetectors[1]}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestUpsertAncestry(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "UpsertAncestry", true)
|
||||
defer closeTest(t, store, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestUpsertAncestry")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range upsertAncestryTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
err := tx.UpsertAncestry(*test.in)
|
||||
err := UpsertAncestry(tx, *test.in)
|
||||
if test.err != "" {
|
||||
assert.EqualError(t, err, test.err, "unexpected error")
|
||||
return
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
actual, ok, err := tx.FindAncestry(test.in.Name)
|
||||
actual, ok, err := FindAncestry(tx, test.in.Name)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
database.AssertAncestryEqual(t, test.in, &actual)
|
||||
@ -113,16 +115,17 @@ var findAncestryTests = []struct {
|
||||
in: "ancestry-2",
|
||||
err: "",
|
||||
ok: true,
|
||||
ancestry: takeAncestryPointerFromMap(realAncestries, 2),
|
||||
ancestry: testutil.TakeAncestryPointerFromMap(testutil.RealAncestries, 2),
|
||||
},
|
||||
}
|
||||
|
||||
func TestFindAncestry(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "FindAncestry", true)
|
||||
defer closeTest(t, store, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindAncestry")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range findAncestryTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
ancestry, ok, err := tx.FindAncestry(test.in)
|
||||
ancestry, ok, err := FindAncestry(tx, test.in)
|
||||
if test.err != "" {
|
||||
assert.EqualError(t, err, test.err, "unexpected error")
|
||||
return
|
@ -1,172 +0,0 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/strutil"
|
||||
)
|
||||
|
||||
const (
|
||||
numVulnerabilities = 100
|
||||
numFeatures = 100
|
||||
)
|
||||
|
||||
func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
featureName := "TestFeature"
|
||||
featureVersionFormat := dpkg.ParserName
|
||||
// Insert the namespace on which we'll work.
|
||||
namespace := database.Namespace{
|
||||
Name: "TestRaceAffectsFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
|
||||
if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Initialize random generator and enforce max procs.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// Generate Distinct random features
|
||||
features := make([]database.Feature, numFeatures)
|
||||
nsFeatures := make([]database.NamespacedFeature, numFeatures)
|
||||
for i := 0; i < numFeatures; i++ {
|
||||
version := rand.Intn(numFeatures)
|
||||
|
||||
features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
|
||||
nsFeatures[i] = database.NamespacedFeature{
|
||||
Namespace: namespace,
|
||||
Feature: features[i],
|
||||
}
|
||||
}
|
||||
|
||||
if !assert.Nil(t, tx.PersistFeatures(features)) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Generate vulnerabilities.
|
||||
vulnerabilities := []database.VulnerabilityWithAffected{}
|
||||
for i := 0; i < numVulnerabilities; i++ {
|
||||
// any version less than this is vulnerable
|
||||
version := rand.Intn(numFeatures) + 1
|
||||
|
||||
vulnerability := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: uuid.New(),
|
||||
Namespace: namespace,
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
Namespace: namespace,
|
||||
FeatureName: featureName,
|
||||
FeatureType: database.SourcePackage,
|
||||
AffectedVersion: strconv.Itoa(version),
|
||||
FixedInVersion: strconv.Itoa(version),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
vulnerabilities = append(vulnerabilities, vulnerability)
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return nsFeatures, vulnerabilities
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
store, cleanup := createTestPgSQL(t, "concurrency")
|
||||
defer cleanup()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// there's a limit on the number of concurrent connections in the pool
|
||||
wg.Add(30)
|
||||
for i := 0; i < 30; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
nsNamespaces := genRandomNamespaces(t, 100)
|
||||
tx, err := store.Begin()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, tx.PersistNamespaces(nsNamespaces))
|
||||
require.Nil(t, tx.Commit())
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
store, cleanup := createTestPgSQL(t, "caching")
|
||||
defer cleanup()
|
||||
|
||||
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
|
||||
tx, err := store.Begin()
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
|
||||
require.Nil(t, tx.Commit())
|
||||
|
||||
tx, err = store.Begin()
|
||||
require.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
|
||||
require.Nil(t, tx.Commit())
|
||||
|
||||
tx, err = store.Begin()
|
||||
require.Nil(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, ansf := range affected {
|
||||
require.True(t, ansf.Valid)
|
||||
|
||||
expectedAffectedNames := []string{}
|
||||
for _, vuln := range vulnerabilities {
|
||||
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
|
||||
if ok {
|
||||
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actualAffectedNames := []string{}
|
||||
for _, s := range ansf.AffectedBy {
|
||||
actualAffectedNames = append(actualAffectedNames, s.Name)
|
||||
}
|
||||
|
||||
require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
|
||||
require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
|
||||
}
|
||||
}
|
@ -1,196 +0,0 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
const (
|
||||
soiDetector = `
|
||||
INSERT INTO detector (name, version, dtype)
|
||||
SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
|
||||
WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`
|
||||
|
||||
selectAncestryDetectors = `
|
||||
SELECT d.name, d.version, d.dtype
|
||||
FROM ancestry_detector, detector AS d
|
||||
WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`
|
||||
|
||||
selectLayerDetectors = `
|
||||
SELECT d.name, d.version, d.dtype
|
||||
FROM layer_detector, detector AS d
|
||||
WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`
|
||||
|
||||
insertAncestryDetectors = `
|
||||
INSERT INTO ancestry_detector (ancestry_id, detector_id)
|
||||
SELECT $1, $2
|
||||
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`
|
||||
|
||||
persistLayerDetector = `
|
||||
INSERT INTO layer_detector (layer_id, detector_id)
|
||||
SELECT $1, $2
|
||||
WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`
|
||||
|
||||
findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
|
||||
findAllDetectors = `SELECT id, name, version, dtype FROM detector`
|
||||
)
|
||||
|
||||
type detectorMap struct {
|
||||
byID map[int64]database.Detector
|
||||
byValue map[database.Detector]int64
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
|
||||
for _, d := range detectors {
|
||||
if !d.Valid() {
|
||||
log.WithField("detector", d).Debug("Invalid Detector")
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
|
||||
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
|
||||
if err != nil {
|
||||
return handleError("soiDetector", err)
|
||||
}
|
||||
|
||||
count, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("soiDetector", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
log.Debug("detector already exists: ", d)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error {
|
||||
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
|
||||
return handleError("persistLayerDetector", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error {
|
||||
alreadySaved := mapset.NewSet()
|
||||
for _, id := range detectorIDs {
|
||||
if alreadySaved.Contains(id) {
|
||||
continue
|
||||
}
|
||||
|
||||
alreadySaved.Add(id)
|
||||
if err := tx.persistLayerDetector(layerID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error {
|
||||
for _, detectorID := range detectorIDs {
|
||||
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
|
||||
return handleError("insertAncestryDetectors", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) {
|
||||
detectors, err := tx.getDetectors(selectAncestryDetectors, id)
|
||||
return detectors, err
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) {
|
||||
detectors, err := tx.getDetectors(selectLayerDetectors, id)
|
||||
return detectors, err
|
||||
}
|
||||
|
||||
// findDetectorIDs retrieve ids of the detectors from the database, if any is not
|
||||
// found, return the error.
|
||||
func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) {
|
||||
ids := []int64{}
|
||||
for _, d := range detectors {
|
||||
id := sql.NullInt64{}
|
||||
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, handleError("findDetectorID", err)
|
||||
}
|
||||
|
||||
if !id.Valid {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
ids = append(ids, id.Int64)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) {
|
||||
rows, err := tx.Query(query, id)
|
||||
if err != nil {
|
||||
return nil, handleError("getDetectors", err)
|
||||
}
|
||||
|
||||
detectors := []database.Detector{}
|
||||
for rows.Next() {
|
||||
d := database.Detector{}
|
||||
err := rows.Scan(&d.Name, &d.Version, &d.DType)
|
||||
if err != nil {
|
||||
return nil, handleError("getDetectors", err)
|
||||
}
|
||||
|
||||
if !d.Valid() {
|
||||
return nil, database.ErrInvalidDetector
|
||||
}
|
||||
|
||||
detectors = append(detectors, d)
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAllDetectors() (detectorMap, error) {
|
||||
rows, err := tx.Query(findAllDetectors)
|
||||
if err != nil {
|
||||
return detectorMap{}, handleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
d database.Detector
|
||||
)
|
||||
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
|
||||
return detectorMap{}, handleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors.byID[id] = d
|
||||
detectors.byValue[d] = id
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
132
database/pgsql/detector/detector.go
Normal file
132
database/pgsql/detector/detector.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package detector
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
const (
|
||||
soiDetector = `
|
||||
INSERT INTO detector (name, version, dtype)
|
||||
SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
|
||||
WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`
|
||||
|
||||
findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
|
||||
findAllDetectors = `SELECT id, name, version, dtype FROM detector`
|
||||
)
|
||||
|
||||
type DetectorMap struct {
|
||||
ByID map[int64]database.Detector
|
||||
ByValue map[database.Detector]int64
|
||||
}
|
||||
|
||||
func PersistDetectors(tx *sql.Tx, detectors []database.Detector) error {
|
||||
for _, d := range detectors {
|
||||
if !d.Valid() {
|
||||
log.WithField("detector", d).Debug("Invalid Detector")
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
|
||||
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
|
||||
if err != nil {
|
||||
return util.HandleError("soiDetector", err)
|
||||
}
|
||||
|
||||
count, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
return util.HandleError("soiDetector", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
log.Debug("detector already exists: ", d)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findDetectorIDs retrieve ids of the detectors from the database, if any is not
|
||||
// found, return the error.
|
||||
func FindDetectorIDs(tx *sql.Tx, detectors []database.Detector) ([]int64, error) {
|
||||
ids := []int64{}
|
||||
for _, d := range detectors {
|
||||
id := sql.NullInt64{}
|
||||
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("findDetectorID", err)
|
||||
}
|
||||
|
||||
if !id.Valid {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
ids = append(ids, id.Int64)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func GetDetectors(tx *sql.Tx, query string, id int64) ([]database.Detector, error) {
|
||||
rows, err := tx.Query(query, id)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("getDetectors", err)
|
||||
}
|
||||
|
||||
detectors := []database.Detector{}
|
||||
for rows.Next() {
|
||||
d := database.Detector{}
|
||||
err := rows.Scan(&d.Name, &d.Version, &d.DType)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("getDetectors", err)
|
||||
}
|
||||
|
||||
if !d.Valid() {
|
||||
return nil, database.ErrInvalidDetector
|
||||
}
|
||||
|
||||
detectors = append(detectors, d)
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
||||
|
||||
func FindAllDetectors(tx *sql.Tx) (DetectorMap, error) {
|
||||
rows, err := tx.Query(findAllDetectors)
|
||||
if err != nil {
|
||||
return DetectorMap{}, util.HandleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors := DetectorMap{ByID: make(map[int64]database.Detector), ByValue: make(map[database.Detector]int64)}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
d database.Detector
|
||||
)
|
||||
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
|
||||
return DetectorMap{}, util.HandleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors.ByID[id] = d
|
||||
detectors.ByValue[d] = id
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
@ -12,18 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package detector
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
func testGetAllDetectors(tx *pgSession) []database.Detector {
|
||||
func testGetAllDetectors(tx *sql.Tx) []database.Detector {
|
||||
query := `SELECT name, version, dtype FROM detector`
|
||||
rows, err := tx.Query(query)
|
||||
if err != nil {
|
||||
@ -90,12 +92,12 @@ var persistDetectorTests = []struct {
|
||||
}
|
||||
|
||||
func TestPersistDetector(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistDetector", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistDetector")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range persistDetectorTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
err := tx.PersistDetectors(test.in)
|
||||
err := PersistDetectors(tx, test.in)
|
||||
if test.err != "" {
|
||||
require.EqualError(t, err, test.err)
|
||||
return
|
@ -1,398 +0,0 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
const (
|
||||
soiNamespacedFeature = `
|
||||
WITH new_feature_ns AS (
|
||||
INSERT INTO namespaced_feature(feature_id, namespace_id)
|
||||
SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
|
||||
WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
|
||||
UNION
|
||||
SELECT id FROM new_feature_ns`
|
||||
|
||||
searchPotentialAffectingVulneraibilities = `
|
||||
SELECT nf.id, v.id, vaf.affected_version, vaf.id
|
||||
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
|
||||
namespaced_feature AS nf, feature AS f
|
||||
WHERE nf.id = ANY($1)
|
||||
AND nf.feature_id = f.id
|
||||
AND nf.namespace_id = v.namespace_id
|
||||
AND vaf.feature_name = f.name
|
||||
AND vaf.feature_type = f.type
|
||||
AND vaf.vulnerability_id = v.id
|
||||
AND v.deleted_at IS NULL`
|
||||
|
||||
searchNamespacedFeaturesVulnerabilities = `
|
||||
SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
|
||||
v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
Vulnerability AS v,
|
||||
vulnerability_affected_feature AS vaf,
|
||||
namespace AS n
|
||||
WHERE vanf.namespaced_feature_id = ANY($1)
|
||||
AND vaf.id = vanf.added_by
|
||||
AND v.id = vanf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
AND v.deleted_at IS NULL`
|
||||
)
|
||||
|
||||
func (tx *pgSession) PersistFeatures(features []database.Feature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sorting is needed before inserting into database to prevent deadlock.
|
||||
sort.Slice(features, func(i, j int) bool {
|
||||
return features[i].Name < features[j].Name ||
|
||||
features[i].Version < features[j].Version ||
|
||||
features[i].VersionFormat < features[j].VersionFormat
|
||||
})
|
||||
|
||||
// TODO(Sida): A better interface for bulk insertion is needed.
|
||||
keys := make([]interface{}, 0, len(features)*3)
|
||||
for _, f := range features {
|
||||
keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type])
|
||||
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
|
||||
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec(queryPersistFeature(len(features)), keys...)
|
||||
return handleError("queryPersistFeature", err)
|
||||
}
|
||||
|
||||
type namespacedFeatureWithID struct {
|
||||
database.NamespacedFeature
|
||||
|
||||
ID int64
|
||||
}
|
||||
|
||||
type vulnerabilityCache struct {
|
||||
nsFeatureID int64
|
||||
vulnID int64
|
||||
vulnAffectingID int64
|
||||
}
|
||||
|
||||
func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ids, err := tx.findNamespacedFeatureIDs(features)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fMap := map[int64]database.NamespacedFeature{}
|
||||
for i, f := range features {
|
||||
if !ids[i].Valid {
|
||||
return nil, database.ErrMissingEntities
|
||||
}
|
||||
fMap[ids[i].Int64] = f
|
||||
}
|
||||
|
||||
cacheTable := []vulnerabilityCache{}
|
||||
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, handleError("searchPotentialAffectingVulneraibilities", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var (
|
||||
cache vulnerabilityCache
|
||||
affected string
|
||||
)
|
||||
|
||||
err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
cacheTable = append(cacheTable, cache)
|
||||
}
|
||||
}
|
||||
|
||||
return cacheTable, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := tx.Exec(lockVulnerabilityAffects)
|
||||
if err != nil {
|
||||
return handleError("lockVulnerabilityAffects", err)
|
||||
}
|
||||
|
||||
cache, err := tx.searchAffectingVulnerabilities(features)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := make([]interface{}, 0, len(cache)*3)
|
||||
for _, c := range cache {
|
||||
keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID)
|
||||
}
|
||||
|
||||
if len(cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
|
||||
if err != nil {
|
||||
return handleError("persistVulnerabilityAffectedNamespacedFeature", err)
|
||||
}
|
||||
if count, err := affected.RowsAffected(); err != nil {
|
||||
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nsIDs := map[database.Namespace]sql.NullInt64{}
|
||||
fIDs := map[database.Feature]sql.NullInt64{}
|
||||
for _, f := range features {
|
||||
nsIDs[f.Namespace] = sql.NullInt64{}
|
||||
fIDs[f.Feature] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
fToFind := []database.Feature{}
|
||||
for f := range fIDs {
|
||||
fToFind = append(fToFind, f)
|
||||
}
|
||||
|
||||
sort.Slice(fToFind, func(i, j int) bool {
|
||||
return fToFind[i].Name < fToFind[j].Name ||
|
||||
fToFind[i].Version < fToFind[j].Version ||
|
||||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
|
||||
})
|
||||
|
||||
if ids, err := tx.findFeatureIDs(fToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
fIDs[fToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
nsToFind := []database.Namespace{}
|
||||
for ns := range nsIDs {
|
||||
nsToFind = append(nsToFind, ns)
|
||||
}
|
||||
|
||||
if ids, err := tx.findNamespaceIDs(nsToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
nsIDs[nsToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := make([]interface{}, 0, len(features)*2)
|
||||
for _, f := range features {
|
||||
keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace])
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
|
||||
// feature.
|
||||
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
|
||||
featureIDs, err := tx.findNamespacedFeatureIDs(features)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, id := range featureIDs {
|
||||
if id.Valid {
|
||||
vulnerableFeatures[i].Valid = true
|
||||
vulnerableFeatures[i].NamespacedFeature = features[i]
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
featureID int64
|
||||
vuln database.VulnerabilityWithFixedIn
|
||||
)
|
||||
|
||||
err := rows.Scan(&featureID,
|
||||
&vuln.Name,
|
||||
&vuln.Description,
|
||||
&vuln.Link,
|
||||
&vuln.Severity,
|
||||
&vuln.Metadata,
|
||||
&vuln.FixedInVersion,
|
||||
&vuln.Namespace.Name,
|
||||
&vuln.Namespace.VersionFormat,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
|
||||
for i, id := range featureIDs {
|
||||
if id.Valid && id.Int64 == featureID {
|
||||
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vulnerableFeatures, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
|
||||
if len(nfs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nfsMap := map[database.NamespacedFeature]int64{}
|
||||
keys := make([]interface{}, 0, len(nfs)*5)
|
||||
for _, nf := range nfs {
|
||||
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeature", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
var (
|
||||
id int64
|
||||
nf database.NamespacedFeature
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
|
||||
nf.Namespace.VersionFormat = nf.VersionFormat
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeature", err)
|
||||
}
|
||||
nfsMap[nf] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(nfs))
|
||||
for i, nf := range nfs {
|
||||
if id, ok := nfsMap[nf]; ok {
|
||||
ids[i] = sql.NullInt64{id, true}
|
||||
} else {
|
||||
ids[i] = sql.NullInt64{}
|
||||
}
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) {
|
||||
if len(fs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fMap := map[database.Feature]sql.NullInt64{}
|
||||
|
||||
keys := make([]interface{}, 0, len(fs)*4)
|
||||
for _, f := range fs {
|
||||
typeID := types.byName[f.Type]
|
||||
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
|
||||
fMap[f] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchFeatureID", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var (
|
||||
id sql.NullInt64
|
||||
f database.Feature
|
||||
)
|
||||
for rows.Next() {
|
||||
var typeID int
|
||||
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchFeatureID", err)
|
||||
}
|
||||
|
||||
f.Type = types.byID[typeID]
|
||||
fMap[f] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(fs))
|
||||
for i, f := range fs {
|
||||
ids[i] = fMap[f]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
121
database/pgsql/feature/feature.go
Normal file
121
database/pgsql/feature/feature.go
Normal file
@ -0,0 +1,121 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package feature
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func queryPersistFeature(count int) string {
|
||||
return util.QueryPersist(count,
|
||||
"feature",
|
||||
"feature_name_version_version_format_type_key",
|
||||
"name",
|
||||
"version",
|
||||
"version_format",
|
||||
"type")
|
||||
}
|
||||
|
||||
func querySearchFeatureID(featureCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT id, name, version, version_format, type
|
||||
FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
|
||||
util.QueryString(4, featureCount),
|
||||
)
|
||||
}
|
||||
|
||||
func PersistFeatures(tx *sql.Tx, features []database.Feature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
types, err := GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sorting is needed before inserting into database to prevent deadlock.
|
||||
sort.Slice(features, func(i, j int) bool {
|
||||
return features[i].Name < features[j].Name ||
|
||||
features[i].Version < features[j].Version ||
|
||||
features[i].VersionFormat < features[j].VersionFormat
|
||||
})
|
||||
|
||||
// TODO(Sida): A better interface for bulk insertion is needed.
|
||||
keys := make([]interface{}, 0, len(features)*3)
|
||||
for _, f := range features {
|
||||
keys = append(keys, f.Name, f.Version, f.VersionFormat, types.ByName[f.Type])
|
||||
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
|
||||
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec(queryPersistFeature(len(features)), keys...)
|
||||
return util.HandleError("queryPersistFeature", err)
|
||||
}
|
||||
|
||||
func FindFeatureIDs(tx *sql.Tx, fs []database.Feature) ([]sql.NullInt64, error) {
|
||||
if len(fs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
types, err := GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fMap := map[database.Feature]sql.NullInt64{}
|
||||
|
||||
keys := make([]interface{}, 0, len(fs)*4)
|
||||
for _, f := range fs {
|
||||
typeID := types.ByName[f.Type]
|
||||
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
|
||||
fMap[f] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("querySearchFeatureID", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var (
|
||||
id sql.NullInt64
|
||||
f database.Feature
|
||||
)
|
||||
for rows.Next() {
|
||||
var typeID int
|
||||
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("querySearchFeatureID", err)
|
||||
}
|
||||
|
||||
f.Type = types.ByID[typeID]
|
||||
fMap[f] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(fs))
|
||||
for i, f := range fs {
|
||||
ids[i] = fMap[f]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
@ -12,36 +12,38 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package feature
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
func TestPersistFeatures(t *testing.T) {
|
||||
tx, cleanup := createTestPgSession(t, "TestPersistFeatures")
|
||||
tx, cleanup := testutil.CreateTestTx(t, "TestPersistFeatures")
|
||||
defer cleanup()
|
||||
|
||||
invalid := database.Feature{}
|
||||
valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg")
|
||||
|
||||
// invalid
|
||||
require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid}))
|
||||
require.NotNil(t, PersistFeatures(tx, []database.Feature{invalid}))
|
||||
// existing
|
||||
require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
|
||||
require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
|
||||
require.Nil(t, PersistFeatures(tx, []database.Feature{valid}))
|
||||
require.Nil(t, PersistFeatures(tx, []database.Feature{valid}))
|
||||
|
||||
features := selectAllFeatures(t, tx)
|
||||
assert.Equal(t, []database.Feature{valid}, features)
|
||||
}
|
||||
|
||||
func TestPersistNamespacedFeatures(t *testing.T) {
|
||||
tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures")
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestPersistNamespacedFeatures")
|
||||
defer cleanup()
|
||||
|
||||
// existing features
|
||||
@ -58,42 +60,17 @@ func TestPersistNamespacedFeatures(t *testing.T) {
|
||||
nf2 := database.NewNamespacedFeature(n2, f2)
|
||||
// namespaced features with namespaces or features not in the database will
|
||||
// generate error.
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{}))
|
||||
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2}))
|
||||
assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{}))
|
||||
assert.NotNil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1, *nf2}))
|
||||
// valid case: insert nf3
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1}))
|
||||
assert.Nil(t, PersistNamespacedFeatures(tx, []database.NamespacedFeature{*nf1}))
|
||||
|
||||
all := listNamespacedFeatures(t, tx)
|
||||
assert.Contains(t, all, *nf1)
|
||||
}
|
||||
|
||||
func TestFindAffectedNamespacedFeatures(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
ns := database.NamespacedFeature{
|
||||
Feature: database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.0",
|
||||
VersionFormat: "dpkg",
|
||||
Type: database.SourcePackage,
|
||||
},
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
}
|
||||
|
||||
ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns})
|
||||
if assert.Nil(t, err) &&
|
||||
assert.Len(t, ans, 1) &&
|
||||
assert.True(t, ans[0].Valid) &&
|
||||
assert.Len(t, ans[0].AffectedBy, 1) {
|
||||
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
func listNamespacedFeatures(t *testing.T, tx *sql.Tx) []database.NamespacedFeature {
|
||||
types, err := GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -114,15 +91,15 @@ func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFe
|
||||
panic(err)
|
||||
}
|
||||
|
||||
f.Type = types.byID[typeID]
|
||||
f.Type = types.ByID[typeID]
|
||||
nf = append(nf, f)
|
||||
}
|
||||
|
||||
return nf
|
||||
}
|
||||
|
||||
func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
func selectAllFeatures(t *testing.T, tx *sql.Tx) []database.Feature {
|
||||
types, err := GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -137,7 +114,7 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
|
||||
f := database.Feature{}
|
||||
var typeID int
|
||||
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID)
|
||||
f.Type = types.byID[typeID]
|
||||
f.Type = types.ByID[typeID]
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
@ -146,45 +123,24 @@ func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
|
||||
return fs
|
||||
}
|
||||
|
||||
func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.NamespacedFeature]bool{}
|
||||
for _, nf := range expected {
|
||||
has[nf] = false
|
||||
}
|
||||
|
||||
for _, nf := range actual {
|
||||
has[nf] = true
|
||||
}
|
||||
|
||||
for nf, visited := range has {
|
||||
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestFindNamespacedFeatureIDs(t *testing.T) {
|
||||
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs")
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNamespacedFeatureIDs")
|
||||
defer cleanup()
|
||||
|
||||
features := []database.NamespacedFeature{}
|
||||
expectedIDs := []int{}
|
||||
for id, feature := range realNamespacedFeatures {
|
||||
for id, feature := range testutil.RealNamespacedFeatures {
|
||||
features = append(features, feature)
|
||||
expectedIDs = append(expectedIDs, id)
|
||||
}
|
||||
|
||||
features = append(features, realNamespacedFeatures[1]) // test duplicated
|
||||
features = append(features, testutil.RealNamespacedFeatures[1]) // test duplicated
|
||||
expectedIDs = append(expectedIDs, 1)
|
||||
|
||||
namespace := realNamespaces[1]
|
||||
namespace := testutil.RealNamespaces[1]
|
||||
features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature
|
||||
|
||||
ids, err := tx.findNamespacedFeatureIDs(features)
|
||||
ids, err := FindNamespacedFeatureIDs(tx, features)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, ids, len(expectedIDs)+1)
|
||||
for i, id := range ids {
|
@ -12,24 +12,28 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package feature
|
||||
|
||||
import "github.com/coreos/clair/database"
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
const (
|
||||
selectAllFeatureTypes = `SELECT id, name FROM feature_type`
|
||||
)
|
||||
|
||||
type featureTypes struct {
|
||||
byID map[int]database.FeatureType
|
||||
byName map[database.FeatureType]int
|
||||
type FeatureTypes struct {
|
||||
ByID map[int]database.FeatureType
|
||||
ByName map[database.FeatureType]int
|
||||
}
|
||||
|
||||
func newFeatureTypes() *featureTypes {
|
||||
return &featureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)}
|
||||
func newFeatureTypes() *FeatureTypes {
|
||||
return &FeatureTypes{make(map[int]database.FeatureType), make(map[database.FeatureType]int)}
|
||||
}
|
||||
|
||||
func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) {
|
||||
func GetFeatureTypeMap(tx *sql.Tx) (*FeatureTypes, error) {
|
||||
rows, err := tx.Query(selectAllFeatureTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -45,8 +49,8 @@ func (tx *pgSession) getFeatureTypeMap() (*featureTypes, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
types.byID[id] = name
|
||||
types.byName[name] = id
|
||||
types.ByID[id] = name
|
||||
types.ByName[name] = id
|
||||
}
|
||||
|
||||
return types, nil
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package feature
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -20,19 +20,20 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
func TestGetFeatureTypeMap(t *testing.T) {
|
||||
tx, cleanup := createTestPgSession(t, "TestGetFeatureTypeMap")
|
||||
tx, cleanup := testutil.CreateTestTx(t, "TestGetFeatureTypeMap")
|
||||
defer cleanup()
|
||||
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
types, err := GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
require.Nil(t, err, err.Error())
|
||||
}
|
||||
|
||||
require.Equal(t, database.SourcePackage, types.byID[1])
|
||||
require.Equal(t, database.BinaryPackage, types.byID[2])
|
||||
require.Equal(t, 1, types.byName[database.SourcePackage])
|
||||
require.Equal(t, 2, types.byName[database.BinaryPackage])
|
||||
require.Equal(t, database.SourcePackage, types.ByID[1])
|
||||
require.Equal(t, database.BinaryPackage, types.ByID[2])
|
||||
require.Equal(t, 1, types.ByName[database.SourcePackage])
|
||||
require.Equal(t, 2, types.ByName[database.BinaryPackage])
|
||||
}
|
168
database/pgsql/feature/namespaced_feature.go
Normal file
168
database/pgsql/feature/namespaced_feature.go
Normal file
@ -0,0 +1,168 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package feature
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
var soiNamespacedFeature = `
|
||||
WITH new_feature_ns AS (
|
||||
INSERT INTO namespaced_feature(feature_id, namespace_id)
|
||||
SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
|
||||
WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
|
||||
UNION
|
||||
SELECT id FROM new_feature_ns`
|
||||
|
||||
func queryPersistNamespacedFeature(count int) string {
|
||||
return util.QueryPersist(count, "namespaced_feature",
|
||||
"namespaced_feature_namespace_id_feature_id_key",
|
||||
"feature_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
func querySearchNamespacedFeature(nsfCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
|
||||
FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
|
||||
WHERE nf.feature_id = f.id
|
||||
AND nf.namespace_id = n.id
|
||||
AND n.version_format = f.version_format
|
||||
AND f.type = t.id
|
||||
AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
|
||||
util.QueryString(5, nsfCount),
|
||||
)
|
||||
}
|
||||
|
||||
type namespacedFeatureWithID struct {
|
||||
database.NamespacedFeature
|
||||
|
||||
ID int64
|
||||
}
|
||||
|
||||
func PersistNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nsIDs := map[database.Namespace]sql.NullInt64{}
|
||||
fIDs := map[database.Feature]sql.NullInt64{}
|
||||
for _, f := range features {
|
||||
nsIDs[f.Namespace] = sql.NullInt64{}
|
||||
fIDs[f.Feature] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
fToFind := []database.Feature{}
|
||||
for f := range fIDs {
|
||||
fToFind = append(fToFind, f)
|
||||
}
|
||||
|
||||
sort.Slice(fToFind, func(i, j int) bool {
|
||||
return fToFind[i].Name < fToFind[j].Name ||
|
||||
fToFind[i].Version < fToFind[j].Version ||
|
||||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
|
||||
})
|
||||
|
||||
if ids, err := FindFeatureIDs(tx, fToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
fIDs[fToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
nsToFind := []database.Namespace{}
|
||||
for ns := range nsIDs {
|
||||
nsToFind = append(nsToFind, ns)
|
||||
}
|
||||
|
||||
if ids, err := namespace.FindNamespaceIDs(tx, nsToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
nsIDs[nsToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := make([]interface{}, 0, len(features)*2)
|
||||
for _, f := range features {
|
||||
keys = append(keys, fIDs[f.Feature], nsIDs[f.Namespace])
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindNamespacedFeatureIDs(tx *sql.Tx, nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
|
||||
if len(nfs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nfsMap := map[database.NamespacedFeature]int64{}
|
||||
keys := make([]interface{}, 0, len(nfs)*5)
|
||||
for _, nf := range nfs {
|
||||
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("searchNamespacedFeature", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
var (
|
||||
id int64
|
||||
nf database.NamespacedFeature
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
|
||||
nf.Namespace.VersionFormat = nf.VersionFormat
|
||||
if err != nil {
|
||||
return nil, util.HandleError("searchNamespacedFeature", err)
|
||||
}
|
||||
nfsMap[nf] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(nfs))
|
||||
for i, nf := range nfs {
|
||||
if id, ok := nfsMap[nf]; ok {
|
||||
ids[i] = sql.NullInt64{id, true}
|
||||
} else {
|
||||
ids[i] = sql.NullInt64{}
|
||||
}
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package keyvalue
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@ -20,6 +20,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/monitoring"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
@ -32,24 +34,24 @@ const (
|
||||
DO UPDATE SET key=$1, value=$2`
|
||||
)
|
||||
|
||||
func (tx *pgSession) UpdateKeyValue(key, value string) (err error) {
|
||||
func UpdateKeyValue(tx *sql.Tx, key, value string) (err error) {
|
||||
if key == "" || value == "" {
|
||||
log.Warning("could not insert a flag which has an empty name or value")
|
||||
return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value")
|
||||
}
|
||||
|
||||
defer observeQueryTime("PersistKeyValue", "all", time.Now())
|
||||
defer monitoring.ObserveQueryTime("PersistKeyValue", "all", time.Now())
|
||||
|
||||
_, err = tx.Exec(upsertKeyValue, key, value)
|
||||
if err != nil {
|
||||
return handleError("insertKeyValue", err)
|
||||
return util.HandleError("insertKeyValue", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindKeyValue(key string) (string, bool, error) {
|
||||
defer observeQueryTime("FindKeyValue", "all", time.Now())
|
||||
func FindKeyValue(tx *sql.Tx, key string) (string, bool, error) {
|
||||
defer monitoring.ObserveQueryTime("FindKeyValue", "all", time.Now())
|
||||
|
||||
var value string
|
||||
err := tx.QueryRow(searchKeyValue, key).Scan(&value)
|
||||
@ -59,7 +61,7 @@ func (tx *pgSession) FindKeyValue(key string) (string, bool, error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", false, handleError("searchKeyValue", err)
|
||||
return "", false, util.HandleError("searchKeyValue", err)
|
||||
}
|
||||
|
||||
return value, true, nil
|
@ -12,38 +12,39 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package keyvalue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestKeyValue(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "KeyValue", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "KeyValue")
|
||||
defer cleanup()
|
||||
|
||||
// Get non-existing key/value
|
||||
f, ok, err := tx.FindKeyValue("test")
|
||||
f, ok, err := FindKeyValue(tx, "test")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Try to insert invalid key/value.
|
||||
assert.Error(t, tx.UpdateKeyValue("test", ""))
|
||||
assert.Error(t, tx.UpdateKeyValue("", "test"))
|
||||
assert.Error(t, tx.UpdateKeyValue("", ""))
|
||||
assert.Error(t, UpdateKeyValue(tx, "test", ""))
|
||||
assert.Error(t, UpdateKeyValue(tx, "", "test"))
|
||||
assert.Error(t, UpdateKeyValue(tx, "", ""))
|
||||
|
||||
// Insert and verify.
|
||||
assert.Nil(t, tx.UpdateKeyValue("test", "test1"))
|
||||
f, ok, err = tx.FindKeyValue("test")
|
||||
assert.Nil(t, UpdateKeyValue(tx, "test", "test1"))
|
||||
f, ok, err = FindKeyValue(tx, "test")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test1", f)
|
||||
|
||||
// Update and verify.
|
||||
assert.Nil(t, tx.UpdateKeyValue("test", "test2"))
|
||||
f, ok, err = tx.FindKeyValue("test")
|
||||
assert.Nil(t, UpdateKeyValue(tx, "test", "test2"))
|
||||
f, ok, err = FindKeyValue(tx, "test")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test2", f)
|
@ -1,379 +0,0 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
const (
|
||||
soiLayer = `
|
||||
WITH new_layer AS (
|
||||
INSERT INTO layer (hash)
|
||||
SELECT CAST ($1 AS VARCHAR)
|
||||
WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM new_Layer
|
||||
UNION
|
||||
SELECT id FROM layer WHERE hash = $1`
|
||||
|
||||
findLayerFeatures = `
|
||||
SELECT
|
||||
f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format
|
||||
FROM
|
||||
layer_feature AS lf
|
||||
LEFT JOIN feature f on f.id = lf.feature_id
|
||||
LEFT JOIN feature_type ft on ft.id = f.type
|
||||
LEFT JOIN namespace ns ON ns.id = lf.namespace_id
|
||||
|
||||
WHERE lf.layer_id = $1`
|
||||
|
||||
findLayerNamespaces = `
|
||||
SELECT ns.name, ns.version_format, ln.detector_id
|
||||
FROM layer_namespace AS ln, namespace AS ns
|
||||
WHERE ln.namespace_id = ns.id
|
||||
AND ln.layer_id = $1`
|
||||
|
||||
findLayerID = `SELECT id FROM layer WHERE hash = $1`
|
||||
)
|
||||
|
||||
// dbLayerNamespace represents the layer_namespace table.
|
||||
type dbLayerNamespace struct {
|
||||
layerID int64
|
||||
namespaceID int64
|
||||
detectorID int64
|
||||
}
|
||||
|
||||
// dbLayerFeature represents the layer_feature table
|
||||
type dbLayerFeature struct {
|
||||
layerID int64
|
||||
featureID int64
|
||||
detectorID int64
|
||||
namespaceID sql.NullInt64
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) {
|
||||
layer := database.Layer{Hash: hash}
|
||||
if hash == "" {
|
||||
return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.")
|
||||
}
|
||||
|
||||
layerID, ok, err := tx.findLayerID(hash)
|
||||
if err != nil || !ok {
|
||||
return layer, ok, err
|
||||
}
|
||||
|
||||
detectorMap, err := tx.findAllDetectors()
|
||||
if err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.By, err = tx.findLayerDetectors(layerID); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.Features, err = tx.findLayerFeatures(layerID, detectorMap); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.Namespaces, err = tx.findLayerNamespaces(layerID, detectorMap); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
return layer, true, nil
|
||||
}
|
||||
|
||||
func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
|
||||
if hash == "" {
|
||||
return commonerr.NewBadRequestError("expected non-empty layer hash")
|
||||
}
|
||||
|
||||
detectedBySet := mapset.NewSet()
|
||||
for _, d := range detectedBy {
|
||||
detectedBySet.Add(d)
|
||||
}
|
||||
|
||||
for _, f := range features {
|
||||
if !detectedBySet.Contains(f.By) {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range namespaces {
|
||||
if !detectedBySet.Contains(n.By) {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PersistLayer saves the content of a layer to the database.
|
||||
func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
|
||||
var (
|
||||
err error
|
||||
id int64
|
||||
detectorIDs []int64
|
||||
)
|
||||
|
||||
if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id, err = tx.soiLayer(hash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if detectorIDs, err = tx.findDetectorIDs(detectedBy); err != nil {
|
||||
if err == commonerr.ErrNotFound {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistLayerDetectors(id, detectorIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistAllLayerFeatures(id, features); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistAllLayerNamespaces(id, namespaces); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistAllLayerNamespaces(layerID int64, namespaces []database.LayerNamespace) error {
|
||||
detectorMap, err := tx.findAllDetectors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(sidac): This kind of type conversion is very useless and wasteful,
|
||||
// we need interfaces around the database models to reduce these kind of
|
||||
// operations.
|
||||
rawNamespaces := make([]database.Namespace, 0, len(namespaces))
|
||||
for _, ns := range namespaces {
|
||||
rawNamespaces = append(rawNamespaces, ns.Namespace)
|
||||
}
|
||||
|
||||
rawNamespaceIDs, err := tx.findNamespaceIDs(rawNamespaces)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces))
|
||||
for i, ns := range namespaces {
|
||||
detectorID := detectorMap.byValue[ns.By]
|
||||
namespaceID := rawNamespaceIDs[i].Int64
|
||||
if !rawNamespaceIDs[i].Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID})
|
||||
}
|
||||
|
||||
return tx.persistLayerNamespaces(dbLayerNamespaces)
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistAllLayerFeatures(layerID int64, features []database.LayerFeature) error {
|
||||
detectorMap, err := tx.findAllDetectors()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var namespaces []database.Namespace
|
||||
for _, feature := range features {
|
||||
namespaces = append(namespaces, feature.PotentialNamespace)
|
||||
}
|
||||
nameSpaceIDs, _ := tx.findNamespaceIDs(namespaces)
|
||||
featureNamespaceMap := map[database.Namespace]sql.NullInt64{}
|
||||
rawFeatures := make([]database.Feature, 0, len(features))
|
||||
for i, f := range features {
|
||||
rawFeatures = append(rawFeatures, f.Feature)
|
||||
if f.PotentialNamespace.Valid() {
|
||||
featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i]
|
||||
}
|
||||
}
|
||||
|
||||
featureIDs, err := tx.findFeatureIDs(rawFeatures)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var namespaceID sql.NullInt64
|
||||
dbFeatures := make([]dbLayerFeature, 0, len(features))
|
||||
for i, f := range features {
|
||||
detectorID := detectorMap.byValue[f.By]
|
||||
if !featureIDs[i].Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
featureID := featureIDs[i].Int64
|
||||
namespaceID = featureNamespaceMap[f.PotentialNamespace]
|
||||
|
||||
dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID})
|
||||
}
|
||||
|
||||
if err := tx.persistLayerFeatures(dbFeatures); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerFeatures(features []dbLayerFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(features, func(i, j int) bool {
|
||||
return features[i].featureID < features[j].featureID
|
||||
})
|
||||
keys := make([]interface{}, 0, len(features)*4)
|
||||
|
||||
for _, f := range features {
|
||||
keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID)
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerFeature", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerNamespaces(namespaces []dbLayerNamespace) error {
|
||||
if len(namespaces) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// for every bulk persist operation, the input data should be sorted.
|
||||
sort.Slice(namespaces, func(i, j int) bool {
|
||||
return namespaces[i].namespaceID < namespaces[j].namespaceID
|
||||
})
|
||||
|
||||
keys := make([]interface{}, 0, len(namespaces)*3)
|
||||
for _, row := range namespaces {
|
||||
keys = append(keys, row.layerID, row.namespaceID, row.detectorID)
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerNamespace", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerNamespaces(layerID int64, detectors detectorMap) ([]database.LayerNamespace, error) {
|
||||
rows, err := tx.Query(findLayerNamespaces, layerID)
|
||||
if err != nil {
|
||||
return nil, handleError("findLayerNamespaces", err)
|
||||
}
|
||||
|
||||
namespaces := []database.LayerNamespace{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
namespace database.LayerNamespace
|
||||
detectorID int64
|
||||
)
|
||||
|
||||
if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namespace.By = detectors.byID[detectorID]
|
||||
namespaces = append(namespaces, namespace)
|
||||
}
|
||||
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]database.LayerFeature, error) {
|
||||
rows, err := tx.Query(findLayerFeatures, layerID)
|
||||
if err != nil {
|
||||
return nil, handleError("findLayerFeatures", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
features := []database.LayerFeature{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
detectorID int64
|
||||
feature database.LayerFeature
|
||||
)
|
||||
var namespaceName, namespaceVersion sql.NullString
|
||||
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil {
|
||||
return nil, handleError("findLayerFeatures", err)
|
||||
}
|
||||
feature.PotentialNamespace.Name = namespaceName.String
|
||||
feature.PotentialNamespace.VersionFormat = namespaceVersion.String
|
||||
|
||||
feature.By = detectors.byID[detectorID]
|
||||
features = append(features, feature)
|
||||
}
|
||||
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerID(hash string) (int64, bool, error) {
|
||||
var layerID int64
|
||||
err := tx.QueryRow(findLayerID, hash).Scan(&layerID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return layerID, false, nil
|
||||
}
|
||||
|
||||
return layerID, false, handleError("findLayerID", err)
|
||||
}
|
||||
|
||||
return layerID, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerIDs(hashes []string) ([]int64, bool, error) {
|
||||
layerIDs := make([]int64, 0, len(hashes))
|
||||
for _, hash := range hashes {
|
||||
id, ok, err := tx.findLayerID(hash)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
layerIDs = append(layerIDs, id)
|
||||
}
|
||||
|
||||
return layerIDs, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) soiLayer(hash string) (int64, error) {
|
||||
var id int64
|
||||
if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
|
||||
return 0, handleError("soiLayer", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
177
database/pgsql/layer/layer.go
Normal file
177
database/pgsql/layer/layer.go
Normal file
@ -0,0 +1,177 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package layer
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
const (
|
||||
soiLayer = `
|
||||
WITH new_layer AS (
|
||||
INSERT INTO layer (hash)
|
||||
SELECT CAST ($1 AS VARCHAR)
|
||||
WHERE NOT EXISTS (SELECT id FROM layer WHERE hash = $1)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM new_Layer
|
||||
UNION
|
||||
SELECT id FROM layer WHERE hash = $1`
|
||||
|
||||
findLayerID = `SELECT id FROM layer WHERE hash = $1`
|
||||
)
|
||||
|
||||
func FindLayer(tx *sql.Tx, hash string) (database.Layer, bool, error) {
|
||||
layer := database.Layer{Hash: hash}
|
||||
if hash == "" {
|
||||
return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.")
|
||||
}
|
||||
|
||||
layerID, ok, err := FindLayerID(tx, hash)
|
||||
if err != nil || !ok {
|
||||
return layer, ok, err
|
||||
}
|
||||
|
||||
detectorMap, err := detector.FindAllDetectors(tx)
|
||||
if err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.By, err = FindLayerDetectors(tx, layerID); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.Features, err = FindLayerFeatures(tx, layerID, detectorMap); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if layer.Namespaces, err = FindLayerNamespaces(tx, layerID, detectorMap); err != nil {
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
return layer, true, nil
|
||||
}
|
||||
|
||||
func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
|
||||
if hash == "" {
|
||||
return commonerr.NewBadRequestError("expected non-empty layer hash")
|
||||
}
|
||||
|
||||
detectedBySet := mapset.NewSet()
|
||||
for _, d := range detectedBy {
|
||||
detectedBySet.Add(d)
|
||||
}
|
||||
|
||||
for _, f := range features {
|
||||
if !detectedBySet.Contains(f.By) {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range namespaces {
|
||||
if !detectedBySet.Contains(n.By) {
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PersistLayer saves the content of a layer to the database.
|
||||
func PersistLayer(tx *sql.Tx, hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
|
||||
var (
|
||||
err error
|
||||
id int64
|
||||
detectorIDs []int64
|
||||
)
|
||||
|
||||
if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id, err = SoiLayer(tx, hash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if detectorIDs, err = detector.FindDetectorIDs(tx, detectedBy); err != nil {
|
||||
if err == commonerr.ErrNotFound {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err = PersistLayerDetectors(tx, id, detectorIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = PersistAllLayerFeatures(tx, id, features); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = PersistAllLayerNamespaces(tx, id, namespaces); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindLayerID(tx *sql.Tx, hash string) (int64, bool, error) {
|
||||
var layerID int64
|
||||
err := tx.QueryRow(findLayerID, hash).Scan(&layerID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return layerID, false, nil
|
||||
}
|
||||
|
||||
return layerID, false, util.HandleError("findLayerID", err)
|
||||
}
|
||||
|
||||
return layerID, true, nil
|
||||
}
|
||||
|
||||
func FindLayerIDs(tx *sql.Tx, hashes []string) ([]int64, bool, error) {
|
||||
layerIDs := make([]int64, 0, len(hashes))
|
||||
for _, hash := range hashes {
|
||||
id, ok, err := FindLayerID(tx, hash)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
layerIDs = append(layerIDs, id)
|
||||
}
|
||||
|
||||
return layerIDs, true, nil
|
||||
}
|
||||
|
||||
func SoiLayer(tx *sql.Tx, hash string) (int64, error) {
|
||||
var id int64
|
||||
if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
|
||||
return 0, util.HandleError("soiLayer", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
66
database/pgsql/layer/layer_detector.go
Normal file
66
database/pgsql/layer/layer_detector.go
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package layer
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
const (
|
||||
selectLayerDetectors = `
|
||||
SELECT d.name, d.version, d.dtype
|
||||
FROM layer_detector, detector AS d
|
||||
WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`
|
||||
|
||||
persistLayerDetector = `
|
||||
INSERT INTO layer_detector (layer_id, detector_id)
|
||||
SELECT $1, $2
|
||||
WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`
|
||||
)
|
||||
|
||||
func PersistLayerDetector(tx *sql.Tx, layerID int64, detectorID int64) error {
|
||||
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
|
||||
return util.HandleError("persistLayerDetector", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func PersistLayerDetectors(tx *sql.Tx, layerID int64, detectorIDs []int64) error {
|
||||
alreadySaved := mapset.NewSet()
|
||||
for _, id := range detectorIDs {
|
||||
if alreadySaved.Contains(id) {
|
||||
continue
|
||||
}
|
||||
|
||||
alreadySaved.Add(id)
|
||||
if err := PersistLayerDetector(tx, layerID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindLayerDetectors(tx *sql.Tx, id int64) ([]database.Detector, error) {
|
||||
detectors, err := detector.GetDetectors(tx, selectLayerDetectors, id)
|
||||
return detectors, err
|
||||
}
|
147
database/pgsql/layer/layer_feature.go
Normal file
147
database/pgsql/layer/layer_feature.go
Normal file
@ -0,0 +1,147 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package layer
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
const findLayerFeatures = `
|
||||
SELECT
|
||||
f.name, f.version, f.version_format, ft.name, lf.detector_id, ns.name, ns.version_format
|
||||
FROM
|
||||
layer_feature AS lf
|
||||
LEFT JOIN feature f on f.id = lf.feature_id
|
||||
LEFT JOIN feature_type ft on ft.id = f.type
|
||||
LEFT JOIN namespace ns ON ns.id = lf.namespace_id
|
||||
|
||||
WHERE lf.layer_id = $1`
|
||||
|
||||
func queryPersistLayerFeature(count int) string {
|
||||
return util.QueryPersist(count,
|
||||
"layer_feature",
|
||||
"layer_feature_layer_id_feature_id_namespace_id_key",
|
||||
"layer_id",
|
||||
"feature_id",
|
||||
"detector_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
// dbLayerFeature represents the layer_feature table
|
||||
type dbLayerFeature struct {
|
||||
layerID int64
|
||||
featureID int64
|
||||
detectorID int64
|
||||
namespaceID sql.NullInt64
|
||||
}
|
||||
|
||||
func FindLayerFeatures(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerFeature, error) {
|
||||
rows, err := tx.Query(findLayerFeatures, layerID)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("findLayerFeatures", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
features := []database.LayerFeature{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
detectorID int64
|
||||
feature database.LayerFeature
|
||||
)
|
||||
var namespaceName, namespaceVersion sql.NullString
|
||||
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID, &namespaceName, &namespaceVersion); err != nil {
|
||||
return nil, util.HandleError("findLayerFeatures", err)
|
||||
}
|
||||
feature.PotentialNamespace.Name = namespaceName.String
|
||||
feature.PotentialNamespace.VersionFormat = namespaceVersion.String
|
||||
|
||||
feature.By = detectors.ByID[detectorID]
|
||||
features = append(features, feature)
|
||||
}
|
||||
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func PersistAllLayerFeatures(tx *sql.Tx, layerID int64, features []database.LayerFeature) error {
|
||||
detectorMap, err := detector.FindAllDetectors(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var namespaces []database.Namespace
|
||||
for _, feature := range features {
|
||||
namespaces = append(namespaces, feature.PotentialNamespace)
|
||||
}
|
||||
nameSpaceIDs, _ := namespace.FindNamespaceIDs(tx, namespaces)
|
||||
featureNamespaceMap := map[database.Namespace]sql.NullInt64{}
|
||||
rawFeatures := make([]database.Feature, 0, len(features))
|
||||
for i, f := range features {
|
||||
rawFeatures = append(rawFeatures, f.Feature)
|
||||
if f.PotentialNamespace.Valid() {
|
||||
featureNamespaceMap[f.PotentialNamespace] = nameSpaceIDs[i]
|
||||
}
|
||||
}
|
||||
|
||||
featureIDs, err := feature.FindFeatureIDs(tx, rawFeatures)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var namespaceID sql.NullInt64
|
||||
dbFeatures := make([]dbLayerFeature, 0, len(features))
|
||||
for i, f := range features {
|
||||
detectorID := detectorMap.ByValue[f.By]
|
||||
featureID := featureIDs[i].Int64
|
||||
if !featureIDs[i].Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
namespaceID = featureNamespaceMap[f.PotentialNamespace]
|
||||
|
||||
dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID, namespaceID})
|
||||
}
|
||||
|
||||
if err := PersistLayerFeatures(tx, dbFeatures); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func PersistLayerFeatures(tx *sql.Tx, features []dbLayerFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(features, func(i, j int) bool {
|
||||
return features[i].featureID < features[j].featureID
|
||||
})
|
||||
keys := make([]interface{}, 0, len(features)*4)
|
||||
|
||||
for _, f := range features {
|
||||
keys = append(keys, f.layerID, f.featureID, f.detectorID, f.namespaceID)
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return util.HandleError("queryPersistLayerFeature", err)
|
||||
}
|
||||
return nil
|
||||
}
|
127
database/pgsql/layer/layer_namespace.go
Normal file
127
database/pgsql/layer/layer_namespace.go
Normal file
@ -0,0 +1,127 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package layer
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
)
|
||||
|
||||
const findLayerNamespaces = `
|
||||
SELECT ns.name, ns.version_format, ln.detector_id
|
||||
FROM layer_namespace AS ln, namespace AS ns
|
||||
WHERE ln.namespace_id = ns.id
|
||||
AND ln.layer_id = $1`
|
||||
|
||||
func queryPersistLayerNamespace(count int) string {
|
||||
return util.QueryPersist(count,
|
||||
"layer_namespace",
|
||||
"layer_namespace_layer_id_namespace_id_key",
|
||||
"layer_id",
|
||||
"namespace_id",
|
||||
"detector_id")
|
||||
}
|
||||
|
||||
// dbLayerNamespace represents the layer_namespace table.
|
||||
type dbLayerNamespace struct {
|
||||
layerID int64
|
||||
namespaceID int64
|
||||
detectorID int64
|
||||
}
|
||||
|
||||
func FindLayerNamespaces(tx *sql.Tx, layerID int64, detectors detector.DetectorMap) ([]database.LayerNamespace, error) {
|
||||
rows, err := tx.Query(findLayerNamespaces, layerID)
|
||||
if err != nil {
|
||||
return nil, util.HandleError("findLayerNamespaces", err)
|
||||
}
|
||||
|
||||
namespaces := []database.LayerNamespace{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
namespace database.LayerNamespace
|
||||
detectorID int64
|
||||
)
|
||||
|
||||
if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namespace.By = detectors.ByID[detectorID]
|
||||
namespaces = append(namespaces, namespace)
|
||||
}
|
||||
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func PersistAllLayerNamespaces(tx *sql.Tx, layerID int64, namespaces []database.LayerNamespace) error {
|
||||
detectorMap, err := detector.FindAllDetectors(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(sidac): This kind of type conversion is very useless and wasteful,
|
||||
// we need interfaces around the database models to reduce these kind of
|
||||
// operations.
|
||||
rawNamespaces := make([]database.Namespace, 0, len(namespaces))
|
||||
for _, ns := range namespaces {
|
||||
rawNamespaces = append(rawNamespaces, ns.Namespace)
|
||||
}
|
||||
|
||||
rawNamespaceIDs, err := namespace.FindNamespaceIDs(tx, rawNamespaces)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces))
|
||||
for i, ns := range namespaces {
|
||||
detectorID := detectorMap.ByValue[ns.By]
|
||||
namespaceID := rawNamespaceIDs[i].Int64
|
||||
if !rawNamespaceIDs[i].Valid {
|
||||
return database.ErrMissingEntities
|
||||
}
|
||||
|
||||
dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID})
|
||||
}
|
||||
|
||||
return PersistLayerNamespaces(tx, dbLayerNamespaces)
|
||||
}
|
||||
|
||||
func PersistLayerNamespaces(tx *sql.Tx, namespaces []dbLayerNamespace) error {
|
||||
if len(namespaces) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// for every bulk persist operation, the input data should be sorted.
|
||||
sort.Slice(namespaces, func(i, j int) bool {
|
||||
return namespaces[i].namespaceID < namespaces[j].namespaceID
|
||||
})
|
||||
|
||||
keys := make([]interface{}, 0, len(namespaces)*3)
|
||||
for _, row := range namespaces {
|
||||
keys = append(keys, row.layerID, row.namespaceID, row.detectorID)
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return util.HandleError("queryPersistLayerNamespace", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package layer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
var persistLayerTests = []struct {
|
||||
@ -39,9 +40,9 @@ var persistLayerTests = []struct {
|
||||
{
|
||||
title: "layer with inconsistent feature and detectors",
|
||||
name: "random-forest",
|
||||
by: []database.Detector{realDetectors[2]},
|
||||
by: []database.Detector{testutil.RealDetectors[2]},
|
||||
features: []database.LayerFeature{
|
||||
{realFeatures[1], realDetectors[1], database.Namespace{}},
|
||||
{testutil.RealFeatures[1], testutil.RealDetectors[1], database.Namespace{}},
|
||||
},
|
||||
err: "parameters are not valid",
|
||||
},
|
||||
@ -49,70 +50,71 @@ var persistLayerTests = []struct {
|
||||
title: "layer with non-existing feature",
|
||||
name: "random-forest",
|
||||
err: "associated immutable entities are missing in the database",
|
||||
by: []database.Detector{realDetectors[2]},
|
||||
by: []database.Detector{testutil.RealDetectors[2]},
|
||||
features: []database.LayerFeature{
|
||||
{fakeFeatures[1], realDetectors[2], database.Namespace{}},
|
||||
{testutil.FakeFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "layer with non-existing namespace",
|
||||
name: "random-forest2",
|
||||
err: "associated immutable entities are missing in the database",
|
||||
by: []database.Detector{realDetectors[1]},
|
||||
by: []database.Detector{testutil.RealDetectors[1]},
|
||||
namespaces: []database.LayerNamespace{
|
||||
{fakeNamespaces[1], realDetectors[1]},
|
||||
{testutil.FakeNamespaces[1], testutil.RealDetectors[1]},
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "layer with non-existing detector",
|
||||
name: "random-forest3",
|
||||
err: "associated immutable entities are missing in the database",
|
||||
by: []database.Detector{fakeDetector[1]},
|
||||
by: []database.Detector{testutil.FakeDetector[1]},
|
||||
},
|
||||
{
|
||||
|
||||
title: "valid layer",
|
||||
name: "hamsterhouse",
|
||||
by: []database.Detector{realDetectors[1], realDetectors[2]},
|
||||
by: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
|
||||
features: []database.LayerFeature{
|
||||
{realFeatures[1], realDetectors[2], database.Namespace{}},
|
||||
{realFeatures[2], realDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
|
||||
},
|
||||
namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[1], realDetectors[1]},
|
||||
{testutil.RealNamespaces[1], testutil.RealDetectors[1]},
|
||||
},
|
||||
layer: &database.Layer{
|
||||
Hash: "hamsterhouse",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2]},
|
||||
By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2]},
|
||||
Features: []database.LayerFeature{
|
||||
{realFeatures[1], realDetectors[2], database.Namespace{}},
|
||||
{realFeatures[2], realDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[1], realDetectors[1]},
|
||||
{testutil.RealNamespaces[1], testutil.RealDetectors[1]},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "update existing layer",
|
||||
name: "layer-1",
|
||||
by: []database.Detector{realDetectors[3], realDetectors[4]},
|
||||
by: []database.Detector{testutil.RealDetectors[3], testutil.RealDetectors[4]},
|
||||
features: []database.LayerFeature{
|
||||
{realFeatures[4], realDetectors[3], database.Namespace{}},
|
||||
{testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}},
|
||||
},
|
||||
namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[3], realDetectors[4]},
|
||||
{testutil.RealNamespaces[3], testutil.RealDetectors[4]},
|
||||
},
|
||||
layer: &database.Layer{
|
||||
Hash: "layer-1",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
|
||||
By: []database.Detector{testutil.RealDetectors[1], testutil.RealDetectors[2], testutil.RealDetectors[3], testutil.RealDetectors[4]},
|
||||
Features: []database.LayerFeature{
|
||||
{realFeatures[1], realDetectors[2], database.Namespace{}},
|
||||
{realFeatures[2], realDetectors[2], database.Namespace{}},
|
||||
{realFeatures[4], realDetectors[3], database.Namespace{}},
|
||||
{testutil.RealFeatures[1], testutil.RealDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[2], testutil.RealDetectors[2], database.Namespace{}},
|
||||
{testutil.RealFeatures[4], testutil.RealDetectors[3], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[1], realDetectors[1]},
|
||||
{realNamespaces[3], realDetectors[4]},
|
||||
{testutil.RealNamespaces[1], testutil.RealDetectors[1]},
|
||||
{testutil.RealNamespaces[3], testutil.RealDetectors[4]},
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -120,33 +122,33 @@ var persistLayerTests = []struct {
|
||||
{
|
||||
title: "layer with potential namespace",
|
||||
name: "layer-potential-namespace",
|
||||
by: []database.Detector{realDetectors[3]},
|
||||
by: []database.Detector{testutil.RealDetectors[3]},
|
||||
features: []database.LayerFeature{
|
||||
{realFeatures[4], realDetectors[3], realNamespaces[4]},
|
||||
{testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]},
|
||||
},
|
||||
namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[3], realDetectors[3]},
|
||||
{testutil.RealNamespaces[3], testutil.RealDetectors[3]},
|
||||
},
|
||||
layer: &database.Layer{
|
||||
Hash: "layer-potential-namespace",
|
||||
By: []database.Detector{realDetectors[3]},
|
||||
By: []database.Detector{testutil.RealDetectors[3]},
|
||||
Features: []database.LayerFeature{
|
||||
{realFeatures[4], realDetectors[3], realNamespaces[4]},
|
||||
{testutil.RealFeatures[4], testutil.RealDetectors[3], testutil.RealNamespaces[4]},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[3], realDetectors[3]},
|
||||
{testutil.RealNamespaces[3], testutil.RealDetectors[3]},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestPersistLayer(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistLayer", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "PersistLayer")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range persistLayerTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
err := tx.PersistLayer(test.name, test.features, test.namespaces, test.by)
|
||||
err := PersistLayer(tx, test.name, test.features, test.namespaces, test.by)
|
||||
if test.err != "" {
|
||||
assert.EqualError(t, err, test.err, "unexpected error")
|
||||
return
|
||||
@ -154,7 +156,7 @@ func TestPersistLayer(t *testing.T) {
|
||||
|
||||
assert.Nil(t, err)
|
||||
if test.layer != nil {
|
||||
layer, ok, err := tx.FindLayer(test.name)
|
||||
layer, ok, err := FindLayer(tx, test.name)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
database.AssertLayerEqual(t, test.layer, &layer)
|
||||
@ -186,17 +188,17 @@ var findLayerTests = []struct {
|
||||
title: "existing layer",
|
||||
in: "layer-4",
|
||||
ok: true,
|
||||
out: takeLayerPointerFromMap(realLayers, 6),
|
||||
out: testutil.TakeLayerPointerFromMap(testutil.RealLayers, 6),
|
||||
},
|
||||
}
|
||||
|
||||
func TestFindLayer(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindLayer", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindLayer")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range findLayerTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
layer, ok, err := tx.FindLayer(test.in)
|
||||
layer, ok, err := FindLayer(tx, test.in)
|
||||
if test.err != "" {
|
||||
assert.EqualError(t, err, test.err, "unexpected error")
|
||||
return
|
@ -12,11 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package lock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/monitoring"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -38,12 +41,12 @@ const (
|
||||
SELECT owner, until FROM lock WHERE name = $1`
|
||||
)
|
||||
|
||||
func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
|
||||
func AcquireLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
|
||||
if lockName == "" || whoami == "" || desiredDuration == 0 {
|
||||
panic("invalid lock parameters")
|
||||
}
|
||||
|
||||
if err := tx.pruneLocks(); err != nil {
|
||||
if err := PruneLocks(tx); err != nil {
|
||||
return false, time.Time{}, err
|
||||
}
|
||||
|
||||
@ -54,22 +57,22 @@ func (tx *pgSession) AcquireLock(lockName, whoami string, desiredDuration time.D
|
||||
lockOwner string
|
||||
)
|
||||
|
||||
defer observeQueryTime("Lock", "soiLock", time.Now())
|
||||
defer monitoring.ObserveQueryTime("Lock", "soiLock", time.Now())
|
||||
err := tx.QueryRow(soiLock, lockName, whoami, desiredLockedUntil).Scan(&lockOwner, &lockedUntil)
|
||||
return lockOwner == whoami, lockedUntil, err
|
||||
return lockOwner == whoami, lockedUntil, util.HandleError("AcquireLock", err)
|
||||
}
|
||||
|
||||
func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
|
||||
func ExtendLock(tx *sql.Tx, lockName, whoami string, desiredDuration time.Duration) (bool, time.Time, error) {
|
||||
if lockName == "" || whoami == "" || desiredDuration == 0 {
|
||||
panic("invalid lock parameters")
|
||||
}
|
||||
|
||||
desiredLockedUntil := time.Now().Add(desiredDuration)
|
||||
|
||||
defer observeQueryTime("Lock", "update", time.Now())
|
||||
defer monitoring.ObserveQueryTime("Lock", "update", time.Now())
|
||||
result, err := tx.Exec(updateLock, lockName, whoami, desiredLockedUntil)
|
||||
if err != nil {
|
||||
return false, time.Time{}, handleError("updateLock", err)
|
||||
return false, time.Time{}, util.HandleError("updateLock", err)
|
||||
}
|
||||
|
||||
if numRows, err := result.RowsAffected(); err == nil {
|
||||
@ -77,27 +80,27 @@ func (tx *pgSession) ExtendLock(lockName, whoami string, desiredDuration time.Du
|
||||
return numRows > 0, desiredLockedUntil, nil
|
||||
}
|
||||
|
||||
return false, time.Time{}, handleError("updateLock", err)
|
||||
return false, time.Time{}, util.HandleError("updateLock", err)
|
||||
}
|
||||
|
||||
func (tx *pgSession) ReleaseLock(name, owner string) error {
|
||||
func ReleaseLock(tx *sql.Tx, name, owner string) error {
|
||||
if name == "" || owner == "" {
|
||||
panic("invalid lock parameters")
|
||||
}
|
||||
|
||||
defer observeQueryTime("Unlock", "all", time.Now())
|
||||
defer monitoring.ObserveQueryTime("Unlock", "all", time.Now())
|
||||
_, err := tx.Exec(removeLock, name, owner)
|
||||
return err
|
||||
}
|
||||
|
||||
// pruneLocks removes every expired locks from the database
|
||||
func (tx *pgSession) pruneLocks() error {
|
||||
defer observeQueryTime("pruneLocks", "all", time.Now())
|
||||
func PruneLocks(tx *sql.Tx) error {
|
||||
defer monitoring.ObserveQueryTime("pruneLocks", "all", time.Now())
|
||||
|
||||
if r, err := tx.Exec(removeLockExpired, time.Now().UTC()); err != nil {
|
||||
return handleError("removeLockExpired", err)
|
||||
return util.HandleError("removeLockExpired", err)
|
||||
} else if affected, err := r.RowsAffected(); err != nil {
|
||||
return handleError("removeLockExpired", err)
|
||||
return util.HandleError("removeLockExpired", err)
|
||||
} else {
|
||||
log.Debugf("Pruned %d Locks", affected)
|
||||
}
|
100
database/pgsql/lock/lock_test.go
Normal file
100
database/pgsql/lock/lock_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package lock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAcquireLockReturnsExistingLockDuration(t *testing.T) {
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "Lock")
|
||||
defer cleanup()
|
||||
|
||||
acquired, originalExpiration, err := AcquireLock(tx, "test1", "owner1", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, acquired)
|
||||
|
||||
acquired2, expiration, err := AcquireLock(tx, "test1", "owner2", time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.False(t, acquired2)
|
||||
require.Equal(t, expiration, originalExpiration)
|
||||
}
|
||||
|
||||
func TestLock(t *testing.T) {
|
||||
db, cleanup := testutil.CreateTestDBWithFixture(t, "Lock")
|
||||
defer cleanup()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a first lock.
|
||||
l, _, err := AcquireLock(tx, "test1", "owner1", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
// lock again by itself, the previous lock is not expired yet.
|
||||
l, _, err = AcquireLock(tx, "test1", "owner1", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, false)
|
||||
|
||||
// Try to renew the same lock with another owner.
|
||||
l, _, err = ExtendLock(tx, "test1", "owner2", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.False(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, false)
|
||||
|
||||
l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.False(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, false)
|
||||
|
||||
// Renew the lock.
|
||||
l, _, err = ExtendLock(tx, "test1", "owner1", 2*time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
// Unlock and then relock by someone else.
|
||||
err = ReleaseLock(tx, "test1", "owner1")
|
||||
require.Nil(t, err)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
l, _, err = AcquireLock(tx, "test1", "owner2", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
// Create a second lock which is actually already expired ...
|
||||
l, _, err = AcquireLock(tx, "test2", "owner1", -time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
// Take over the lock
|
||||
l, _, err = AcquireLock(tx, "test2", "owner2", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, l)
|
||||
tx = testutil.RestartTransaction(db, tx, true)
|
||||
|
||||
require.Nil(t, tx.Rollback())
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAcquireLockReturnsExistingLockDuration(t *testing.T) {
|
||||
tx, cleanup := createTestPgSessionWithFixtures(t, "Lock")
|
||||
defer cleanup()
|
||||
|
||||
acquired, originalExpiration, err := tx.AcquireLock("test1", "owner1", time.Minute)
|
||||
require.Nil(t, err)
|
||||
require.True(t, acquired)
|
||||
|
||||
acquired2, expiration, err := tx.AcquireLock("test1", "owner2", time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.False(t, acquired2)
|
||||
require.Equal(t, expiration, originalExpiration)
|
||||
}
|
||||
|
||||
func TestLock(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "Lock", true)
|
||||
defer datastore.Close()
|
||||
|
||||
var l bool
|
||||
|
||||
// Create a first lock.
|
||||
l, _, err := tx.AcquireLock("test1", "owner1", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// lock again by itself, the previous lock is not expired yet.
|
||||
l, _, err = tx.AcquireLock("test1", "owner1", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l) // acquire lock no-op when owner already has the lock.
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
// Try to renew the same lock with another owner.
|
||||
l, _, err = tx.ExtendLock("test1", "owner2", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, l)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
l, _, err = tx.AcquireLock("test1", "owner2", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, l)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
// Renew the lock.
|
||||
l, _, err = tx.ExtendLock("test1", "owner1", 2*time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Unlock and then relock by someone else.
|
||||
err = tx.ReleaseLock("test1", "owner1")
|
||||
assert.Nil(t, err)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
l, _, err = tx.AcquireLock("test1", "owner2", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Create a second lock which is actually already expired ...
|
||||
l, _, err = tx.AcquireLock("test2", "owner1", -time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Take over the lock
|
||||
l, _, err = tx.AcquireLock("test2", "owner2", time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
if !assert.Nil(t, tx.Rollback()) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
@ -12,22 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package migrations_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/remind101/migrate"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
)
|
||||
|
||||
var userTableCount = `SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'`
|
||||
|
||||
func TestMigration(t *testing.T) {
|
||||
db, cleanup := createAndConnectTestDB(t, "TestMigration")
|
||||
db, cleanup := testutil.CreateAndConnectTestDB(t, "TestMigration")
|
||||
defer cleanup()
|
||||
|
||||
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)
|
67
database/pgsql/monitoring/prometheus.go
Normal file
67
database/pgsql/monitoring/prometheus.go
Normal file
@ -0,0 +1,67 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var (
|
||||
PromErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_errors_total",
|
||||
Help: "Number of errors that PostgreSQL requests generated.",
|
||||
}, []string{"request"})
|
||||
|
||||
PromCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_cache_hits_total",
|
||||
Help: "Number of cache hits that the PostgreSQL backend did.",
|
||||
}, []string{"object"})
|
||||
|
||||
PromCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_cache_queries_total",
|
||||
Help: "Number of cache queries that the PostgreSQL backend did.",
|
||||
}, []string{"object"})
|
||||
|
||||
PromQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "clair_pgsql_query_duration_milliseconds",
|
||||
Help: "Time it takes to execute the database query.",
|
||||
}, []string{"query", "subquery"})
|
||||
|
||||
PromConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "clair_pgsql_concurrent_lock_vafv_total",
|
||||
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
|
||||
})
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(PromErrorsTotal)
|
||||
prometheus.MustRegister(PromCacheHitsTotal)
|
||||
prometheus.MustRegister(PromCacheQueriesTotal)
|
||||
prometheus.MustRegister(PromQueryDurationMilliseconds)
|
||||
prometheus.MustRegister(PromConcurrentLockVAFV)
|
||||
}
|
||||
|
||||
// monitoring.ObserveQueryTime computes the time elapsed since `start` to represent the
|
||||
// query time.
|
||||
// 1. `query` is a pgSession function name.
|
||||
// 2. `subquery` is a specific query or a batched query.
|
||||
// 3. `start` is the time right before query is executed.
|
||||
func ObserveQueryTime(query, subquery string, start time.Time) {
|
||||
PromQueryDurationMilliseconds.
|
||||
WithLabelValues(query, subquery).
|
||||
Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond))
|
||||
}
|
@ -12,13 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package namespace
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
@ -26,8 +28,24 @@ const (
|
||||
searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
|
||||
)
|
||||
|
||||
func queryPersistNamespace(count int) string {
|
||||
return util.QueryPersist(count,
|
||||
"namespace",
|
||||
"namespace_name_version_format_key",
|
||||
"name",
|
||||
"version_format")
|
||||
}
|
||||
|
||||
func querySearchNamespace(nsCount int) string {
|
||||
return fmt.Sprintf(
|
||||
`SELECT id, name, version_format
|
||||
FROM namespace WHERE (name, version_format) IN (%s)`,
|
||||
util.QueryString(2, nsCount),
|
||||
)
|
||||
}
|
||||
|
||||
// PersistNamespaces soi namespaces into database.
|
||||
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
|
||||
func PersistNamespaces(tx *sql.Tx, namespaces []database.Namespace) error {
|
||||
if len(namespaces) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -49,12 +67,12 @@ func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
|
||||
|
||||
_, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistNamespace", err)
|
||||
return util.HandleError("queryPersistNamespace", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) {
|
||||
func FindNamespaceIDs(tx *sql.Tx, namespaces []database.Namespace) ([]sql.NullInt64, error) {
|
||||
if len(namespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@ -69,7 +87,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu
|
||||
|
||||
rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespace", err)
|
||||
return nil, util.HandleError("searchNamespace", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
@ -81,7 +99,7 @@ func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.Nu
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespace", err)
|
||||
return nil, util.HandleError("searchNamespace", err)
|
||||
}
|
||||
nsMap[ns] = id
|
||||
}
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package namespace
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -20,25 +20,26 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
func TestPersistNamespaces(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistNamespaces", false)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTx(t, "PersistNamespaces")
|
||||
defer cleanup()
|
||||
|
||||
ns1 := database.Namespace{}
|
||||
ns2 := database.Namespace{Name: "t", VersionFormat: "b"}
|
||||
|
||||
// Empty Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{}))
|
||||
assert.Nil(t, PersistNamespaces(tx, []database.Namespace{}))
|
||||
// Invalid Case
|
||||
assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1}))
|
||||
assert.NotNil(t, PersistNamespaces(tx, []database.Namespace{ns1}))
|
||||
// Duplicated Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2}))
|
||||
assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2, ns2}))
|
||||
// Existing Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2}))
|
||||
assert.Nil(t, PersistNamespaces(tx, []database.Namespace{ns2}))
|
||||
|
||||
nsList := listNamespaces(t, tx)
|
||||
nsList := testutil.ListNamespaces(t, tx)
|
||||
assert.Len(t, nsList, 1)
|
||||
assert.Equal(t, ns2, nsList[0])
|
||||
}
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package notification
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -22,6 +22,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/page"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
@ -38,6 +40,8 @@ type findVulnerabilityNotificationOut struct {
|
||||
err string
|
||||
}
|
||||
|
||||
var testPaginationKey = pagination.Must(pagination.NewKey())
|
||||
|
||||
var findVulnerabilityNotificationTests = []struct {
|
||||
title string
|
||||
in findVulnerabilityNotificationIn
|
||||
@ -77,21 +81,21 @@ var findVulnerabilityNotificationTests = []struct {
|
||||
},
|
||||
out: findVulnerabilityNotificationOut{
|
||||
&database.VulnerabilityNotificationWithVulnerable{
|
||||
NotificationHook: realNotification[1].NotificationHook,
|
||||
NotificationHook: testutil.RealNotification[1].NotificationHook,
|
||||
Old: &database.PagedVulnerableAncestries{
|
||||
Vulnerability: realVulnerability[2],
|
||||
Vulnerability: testutil.RealVulnerability[2],
|
||||
Limit: 1,
|
||||
Affected: make(map[int]string),
|
||||
Current: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Next: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
End: true,
|
||||
},
|
||||
New: &database.PagedVulnerableAncestries{
|
||||
Vulnerability: realVulnerability[1],
|
||||
Vulnerability: testutil.RealVulnerability[1],
|
||||
Limit: 1,
|
||||
Affected: map[int]string{3: "ancestry-3"},
|
||||
Current: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Next: mustMarshalToken(testPaginationKey, Page{4}),
|
||||
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
|
||||
End: false,
|
||||
},
|
||||
},
|
||||
@ -100,32 +104,31 @@ var findVulnerabilityNotificationTests = []struct {
|
||||
"",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
title: "find existing notification of second page of new affected ancestry",
|
||||
in: findVulnerabilityNotificationIn{
|
||||
notificationName: "test",
|
||||
pageSize: 1,
|
||||
oldAffectedAncestryPage: pagination.FirstPageToken,
|
||||
newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}),
|
||||
newAffectedAncestryPage: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
|
||||
},
|
||||
out: findVulnerabilityNotificationOut{
|
||||
&database.VulnerabilityNotificationWithVulnerable{
|
||||
NotificationHook: realNotification[1].NotificationHook,
|
||||
NotificationHook: testutil.RealNotification[1].NotificationHook,
|
||||
Old: &database.PagedVulnerableAncestries{
|
||||
Vulnerability: realVulnerability[2],
|
||||
Vulnerability: testutil.RealVulnerability[2],
|
||||
Limit: 1,
|
||||
Affected: make(map[int]string),
|
||||
Current: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Next: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
End: true,
|
||||
},
|
||||
New: &database.PagedVulnerableAncestries{
|
||||
Vulnerability: realVulnerability[1],
|
||||
Vulnerability: testutil.RealVulnerability[1],
|
||||
Limit: 1,
|
||||
Affected: map[int]string{4: "ancestry-4"},
|
||||
Current: mustMarshalToken(testPaginationKey, Page{4}),
|
||||
Next: mustMarshalToken(testPaginationKey, Page{0}),
|
||||
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
|
||||
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
|
||||
End: true,
|
||||
},
|
||||
},
|
||||
@ -137,12 +140,12 @@ var findVulnerabilityNotificationTests = []struct {
|
||||
}
|
||||
|
||||
func TestFindVulnerabilityNotification(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "pagination", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "pagination")
|
||||
defer cleanup()
|
||||
|
||||
for _, test := range findVulnerabilityNotificationTests {
|
||||
t.Run(test.title, func(t *testing.T) {
|
||||
notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage)
|
||||
notification, ok, err := FindVulnerabilityNotification(tx, test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage, testutil.TestPaginationKey)
|
||||
if test.out.err != "" {
|
||||
require.EqualError(t, err, test.out.err)
|
||||
return
|
||||
@ -155,13 +158,14 @@ func TestFindVulnerabilityNotification(t *testing.T) {
|
||||
}
|
||||
|
||||
require.True(t, ok)
|
||||
assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, ¬ification)
|
||||
testutil.AssertVulnerabilityNotificationWithVulnerableEqual(t, testutil.TestPaginationKey, test.out.notification, ¬ification)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertVulnerabilityNotifications(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true)
|
||||
datastore, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilityNotifications")
|
||||
defer cleanup()
|
||||
|
||||
n1 := database.VulnerabilityNotification{}
|
||||
n3 := database.VulnerabilityNotification{
|
||||
@ -187,34 +191,37 @@ func TestInsertVulnerabilityNotifications(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tx, err := datastore.Begin()
|
||||
require.Nil(t, err)
|
||||
|
||||
// invalid case
|
||||
err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1})
|
||||
assert.NotNil(t, err)
|
||||
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n1})
|
||||
require.NotNil(t, err)
|
||||
|
||||
// invalid case: unknown vulnerability
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3})
|
||||
assert.NotNil(t, err)
|
||||
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n3})
|
||||
require.NotNil(t, err)
|
||||
|
||||
// invalid case: duplicated input notification
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4})
|
||||
assert.NotNil(t, err)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4, n4})
|
||||
require.NotNil(t, err)
|
||||
tx = testutil.RestartTransaction(datastore, tx, false)
|
||||
|
||||
// valid case
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
|
||||
assert.Nil(t, err)
|
||||
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
|
||||
require.Nil(t, err)
|
||||
// invalid case: notification is already in database
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
|
||||
assert.NotNil(t, err)
|
||||
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
|
||||
require.NotNil(t, err)
|
||||
|
||||
closeTest(t, datastore, tx)
|
||||
require.Nil(t, tx.Rollback())
|
||||
}
|
||||
|
||||
func TestFindNewNotification(t *testing.T) {
|
||||
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification")
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNewNotification")
|
||||
defer cleanup()
|
||||
|
||||
noti, ok, err := tx.FindNewNotification(time.Now())
|
||||
noti, ok, err := FindNewNotification(tx, time.Now())
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
assert.Equal(t, time.Time{}, noti.Notified)
|
||||
@ -223,13 +230,13 @@ func TestFindNewNotification(t *testing.T) {
|
||||
}
|
||||
|
||||
// can't find the notified
|
||||
assert.Nil(t, tx.MarkNotificationAsRead("test"))
|
||||
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
|
||||
// if the notified time is before
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second)))
|
||||
noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(10*time.Second)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
// can find the notified after a period of time
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second)))
|
||||
noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(10*time.Second)))
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
assert.NotEqual(t, time.Time{}, noti.Notified)
|
||||
@ -237,37 +244,37 @@ func TestFindNewNotification(t *testing.T) {
|
||||
assert.Equal(t, time.Time{}, noti.Deleted)
|
||||
}
|
||||
|
||||
assert.Nil(t, tx.DeleteNotification("test"))
|
||||
assert.Nil(t, DeleteNotification(tx, "test"))
|
||||
// can't find in any time
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000)))
|
||||
noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(1000)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
|
||||
noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(1000)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMarkNotificationAsRead(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "MarkNotificationAsRead", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "MarkNotificationAsRead")
|
||||
defer cleanup()
|
||||
|
||||
// invalid case: notification doesn't exist
|
||||
assert.NotNil(t, tx.MarkNotificationAsRead("non-existing"))
|
||||
assert.NotNil(t, MarkNotificationAsRead(tx, "non-existing"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.MarkNotificationAsRead("test"))
|
||||
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.MarkNotificationAsRead("test"))
|
||||
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
|
||||
}
|
||||
|
||||
func TestDeleteNotification(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "DeleteNotification", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteNotification")
|
||||
defer cleanup()
|
||||
|
||||
// invalid case: notification doesn't exist
|
||||
assert.NotNil(t, tx.DeleteNotification("non-existing"))
|
||||
assert.NotNil(t, DeleteNotification(tx, "non-existing"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.DeleteNotification("test"))
|
||||
assert.Nil(t, DeleteNotification(tx, "test"))
|
||||
// invalid case: notification is already deleted
|
||||
assert.NotNil(t, tx.DeleteNotification("test"))
|
||||
assert.NotNil(t, DeleteNotification(tx, "test"))
|
||||
}
|
@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package notification
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@ -22,6 +22,8 @@ import (
|
||||
"github.com/guregu/null/zero"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/database/pgsql/vulnerability"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
@ -54,26 +56,24 @@ const (
|
||||
SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
|
||||
FROM Vulnerability_Notification
|
||||
WHERE name = $1`
|
||||
|
||||
searchNotificationVulnerableAncestry = `
|
||||
SELECT DISTINCT ON (a.id)
|
||||
a.id, a.name
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
|
||||
WHERE vanf.vulnerability_id = $1
|
||||
AND a.id >= $2
|
||||
AND al.ancestry_id = a.id
|
||||
AND al.id = af.ancestry_layer_id
|
||||
AND af.namespaced_feature_id = vanf.namespaced_feature_id
|
||||
ORDER BY a.id ASC
|
||||
LIMIT $3;`
|
||||
)
|
||||
|
||||
func queryInsertNotifications(count int) string {
|
||||
return util.QueryInsert(count,
|
||||
"vulnerability_notification",
|
||||
"name",
|
||||
"created_at",
|
||||
"old_vulnerability_id",
|
||||
"new_vulnerability_id",
|
||||
)
|
||||
}
|
||||
|
||||
var (
|
||||
errNotificationNotFound = errors.New("requested notification is not found")
|
||||
errVulnerabilityNotFound = errors.New("vulnerability is not in database")
|
||||
)
|
||||
|
||||
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
|
||||
func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error {
|
||||
if len(notifications) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -122,26 +122,26 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
|
||||
oldVulnIDs = append(oldVulnIDs, vulnID)
|
||||
}
|
||||
|
||||
ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs)
|
||||
ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
}
|
||||
newVulnIDMap[newVulnIDs[i]] = id
|
||||
}
|
||||
|
||||
ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs)
|
||||
ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
}
|
||||
oldVulnIDMap[oldVulnIDs[i]] = id
|
||||
}
|
||||
@ -178,13 +178,13 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
|
||||
// multiple updaters, deadlock may happen.
|
||||
_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryInsertNotifications", err)
|
||||
return util.HandleError("queryInsertNotifications", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) {
|
||||
func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) {
|
||||
var (
|
||||
notification database.NotificationHook
|
||||
created zero.Time
|
||||
@ -197,7 +197,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
|
||||
if err == sql.ErrNoRows {
|
||||
return notification, false, nil
|
||||
}
|
||||
return notification, false, handleError("searchNotificationAvailable", err)
|
||||
return notification, false, util.HandleError("searchNotificationAvailable", err)
|
||||
}
|
||||
|
||||
notification.Created = created.Time
|
||||
@ -207,71 +207,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
|
||||
return notification, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) {
|
||||
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
||||
currentPage := Page{0}
|
||||
if currentToken != pagination.FirstPageToken {
|
||||
if err := tx.key.UnmarshalToken(currentToken, ¤tPage); err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
|
||||
&vulnPage.Name,
|
||||
&vulnPage.Description,
|
||||
&vulnPage.Link,
|
||||
&vulnPage.Severity,
|
||||
&vulnPage.Metadata,
|
||||
&vulnPage.Namespace.Name,
|
||||
&vulnPage.Namespace.VersionFormat,
|
||||
); err != nil {
|
||||
return vulnPage, handleError("searchVulnerabilityByID", err)
|
||||
}
|
||||
|
||||
// the last result is used for the next page's startID
|
||||
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
|
||||
if err != nil {
|
||||
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ancestries := []affectedAncestry{}
|
||||
for rows.Next() {
|
||||
var ancestry affectedAncestry
|
||||
err := rows.Scan(&ancestry.id, &ancestry.name)
|
||||
if err != nil {
|
||||
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
ancestries = append(ancestries, ancestry)
|
||||
}
|
||||
|
||||
lastIndex := 0
|
||||
if len(ancestries)-1 < limit {
|
||||
lastIndex = len(ancestries)
|
||||
vulnPage.End = true
|
||||
} else {
|
||||
// Use the last ancestry's ID as the next page.
|
||||
lastIndex = len(ancestries) - 1
|
||||
vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
vulnPage.Affected = map[int]string{}
|
||||
for _, ancestry := range ancestries[0:lastIndex] {
|
||||
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
||||
}
|
||||
|
||||
vulnPage.Current, err = tx.key.MarshalToken(currentPage)
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
|
||||
return vulnPage, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
|
||||
func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) (
|
||||
database.VulnerabilityNotificationWithVulnerable, bool, error) {
|
||||
var (
|
||||
noti database.VulnerabilityNotificationWithVulnerable
|
||||
@ -294,7 +230,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
||||
if err == sql.ErrNoRows {
|
||||
return noti, false, nil
|
||||
}
|
||||
return noti, false, handleError("searchNotification", err)
|
||||
return noti, false, util.HandleError("searchNotification", err)
|
||||
}
|
||||
|
||||
if created.Valid {
|
||||
@ -310,7 +246,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
||||
}
|
||||
|
||||
if oldVulnID.Valid {
|
||||
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken)
|
||||
page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key)
|
||||
if err != nil {
|
||||
return noti, false, err
|
||||
}
|
||||
@ -318,7 +254,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
||||
}
|
||||
|
||||
if newVulnID.Valid {
|
||||
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken)
|
||||
page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key)
|
||||
if err != nil {
|
||||
return noti, false, err
|
||||
}
|
||||
@ -328,44 +264,44 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
||||
return noti, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) MarkNotificationAsRead(name string) error {
|
||||
func MarkNotificationAsRead(tx *sql.Tx, name string) error {
|
||||
if name == "" {
|
||||
return commonerr.NewBadRequestError("Empty notification name is not allowed")
|
||||
}
|
||||
|
||||
r, err := tx.Exec(updatedNotificationAsRead, name)
|
||||
if err != nil {
|
||||
return handleError("updatedNotificationAsRead", err)
|
||||
return util.HandleError("updatedNotificationAsRead", err)
|
||||
}
|
||||
|
||||
affected, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("updatedNotificationAsRead", err)
|
||||
return util.HandleError("updatedNotificationAsRead", err)
|
||||
}
|
||||
|
||||
if affected <= 0 {
|
||||
return handleError("updatedNotificationAsRead", errNotificationNotFound)
|
||||
return util.HandleError("updatedNotificationAsRead", errNotificationNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) DeleteNotification(name string) error {
|
||||
func DeleteNotification(tx *sql.Tx, name string) error {
|
||||
if name == "" {
|
||||
return commonerr.NewBadRequestError("Empty notification name is not allowed")
|
||||
}
|
||||
|
||||
result, err := tx.Exec(removeNotification, name)
|
||||
if err != nil {
|
||||
return handleError("removeNotification", err)
|
||||
return util.HandleError("removeNotification", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("removeNotification", err)
|
||||
return util.HandleError("removeNotification", err)
|
||||
}
|
||||
|
||||
if affected <= 0 {
|
||||
return handleError("removeNotification", commonerr.ErrNotFound)
|
||||
return util.HandleError("removeNotification", commonerr.ErrNotFound)
|
||||
}
|
||||
|
||||
return nil
|
24
database/pgsql/page/page.go
Normal file
24
database/pgsql/page/page.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package page
|
||||
|
||||
// Page is the representation of a page for the Postgres schema.
|
||||
type Page struct {
|
||||
// StartID is the ID being used as the basis for pagination across database
|
||||
// results. It is used to search for an ancestry with ID >= StartID.
|
||||
//
|
||||
// StartID is required to be unique to every ancestry and always increasing.
|
||||
StartID int64
|
||||
}
|
134
database/pgsql/pgsession.go
Normal file
134
database/pgsql/pgsession.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/keyvalue"
|
||||
"github.com/coreos/clair/database/pgsql/vulnerability"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/ancestry"
|
||||
"github.com/coreos/clair/database/pgsql/detector"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/layer"
|
||||
"github.com/coreos/clair/database/pgsql/lock"
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
"github.com/coreos/clair/database/pgsql/notification"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
// Enforce the interface at compile time.
|
||||
var _ database.Session = &pgSession{}
|
||||
|
||||
type pgSession struct {
|
||||
*sql.Tx
|
||||
|
||||
key pagination.Key
|
||||
}
|
||||
|
||||
func (tx *pgSession) UpsertAncestry(a database.Ancestry) error {
|
||||
return ancestry.UpsertAncestry(tx.Tx, a)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) {
|
||||
return ancestry.FindAncestry(tx.Tx, name)
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
|
||||
return detector.PersistDetectors(tx.Tx, detectors)
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistFeatures(features []database.Feature) error {
|
||||
return feature.PersistFeatures(tx.Tx, features)
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
return feature.PersistNamespacedFeatures(tx.Tx, features)
|
||||
}
|
||||
|
||||
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
return vulnerability.CacheAffectedNamespacedFeatures(tx.Tx, features)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
|
||||
return vulnerability.FindAffectedNamespacedFeatures(tx.Tx, features)
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
|
||||
return namespace.PersistNamespaces(tx.Tx, namespaces)
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
|
||||
return layer.PersistLayer(tx.Tx, hash, features, namespaces, detectedBy)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) {
|
||||
return layer.FindLayer(tx.Tx, hash)
|
||||
}
|
||||
|
||||
func (tx *pgSession) InsertVulnerabilities(vulns []database.VulnerabilityWithAffected) error {
|
||||
return vulnerability.InsertVulnerabilities(tx.Tx, vulns)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindVulnerabilities(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
||||
return vulnerability.FindVulnerabilities(tx.Tx, ids)
|
||||
}
|
||||
|
||||
func (tx *pgSession) DeleteVulnerabilities(ids []database.VulnerabilityID) error {
|
||||
return vulnerability.DeleteVulnerabilities(tx.Tx, ids)
|
||||
}
|
||||
|
||||
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
|
||||
return notification.InsertVulnerabilityNotifications(tx.Tx, notifications)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (hook database.NotificationHook, found bool, err error) {
|
||||
return notification.FindNewNotification(tx.Tx, notifiedBefore)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti database.VulnerabilityNotificationWithVulnerable, found bool, err error) {
|
||||
return notification.FindVulnerabilityNotification(tx.Tx, name, limit, oldVulnerabilityPage, newVulnerabilityPage, tx.key)
|
||||
}
|
||||
|
||||
func (tx *pgSession) MarkNotificationAsRead(name string) error {
|
||||
return notification.MarkNotificationAsRead(tx.Tx, name)
|
||||
}
|
||||
|
||||
func (tx *pgSession) DeleteNotification(name string) error {
|
||||
return notification.DeleteNotification(tx.Tx, name)
|
||||
}
|
||||
|
||||
func (tx *pgSession) UpdateKeyValue(key, value string) error {
|
||||
return keyvalue.UpdateKeyValue(tx.Tx, key, value)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindKeyValue(key string) (value string, found bool, err error) {
|
||||
return keyvalue.FindKeyValue(tx.Tx, key)
|
||||
}
|
||||
|
||||
func (tx *pgSession) AcquireLock(name, owner string, duration time.Duration) (acquired bool, expiration time.Time, err error) {
|
||||
return lock.AcquireLock(tx.Tx, name, owner, duration)
|
||||
}
|
||||
|
||||
func (tx *pgSession) ExtendLock(name, owner string, duration time.Duration) (extended bool, expiration time.Time, err error) {
|
||||
return lock.ExtendLock(tx.Tx, name, owner, duration)
|
||||
}
|
||||
|
||||
func (tx *pgSession) ReleaseLock(name, owner string) error {
|
||||
return lock.ReleaseLock(tx.Tx, name, owner)
|
||||
}
|
56
database/pgsql/pgsession_test.go
Normal file
56
database/pgsql/pgsession_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
numVulnerabilities = 100
|
||||
numFeatures = 100
|
||||
)
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
db, cleanup := testutil.CreateTestDB(t, "concurrency")
|
||||
defer cleanup()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
// there's a limit on the number of concurrent connections in the pool
|
||||
wg.Add(30)
|
||||
for i := 0; i < 30; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
nsNamespaces := testutil.GenRandomNamespaces(t, 100)
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
assert.Nil(t, namespace.PersistNamespaces(tx, nsNamespaces))
|
||||
if err := tx.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
@ -21,13 +21,10 @@ import (
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/remind101/migrate"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -37,50 +34,10 @@ import (
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_errors_total",
|
||||
Help: "Number of errors that PostgreSQL requests generated.",
|
||||
}, []string{"request"})
|
||||
|
||||
promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_cache_hits_total",
|
||||
Help: "Number of cache hits that the PostgreSQL backend did.",
|
||||
}, []string{"object"})
|
||||
|
||||
promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "clair_pgsql_cache_queries_total",
|
||||
Help: "Number of cache queries that the PostgreSQL backend did.",
|
||||
}, []string{"object"})
|
||||
|
||||
promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "clair_pgsql_query_duration_milliseconds",
|
||||
Help: "Time it takes to execute the database query.",
|
||||
}, []string{"query", "subquery"})
|
||||
|
||||
promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "clair_pgsql_concurrent_lock_vafv_total",
|
||||
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
|
||||
})
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(promErrorsTotal)
|
||||
prometheus.MustRegister(promCacheHitsTotal)
|
||||
prometheus.MustRegister(promCacheQueriesTotal)
|
||||
prometheus.MustRegister(promQueryDurationMilliseconds)
|
||||
prometheus.MustRegister(promConcurrentLockVAFV)
|
||||
|
||||
database.Register("pgsql", openDatabase)
|
||||
}
|
||||
|
||||
// pgSessionCache is the session's cache, which holds the pgSQL's cache and the
|
||||
// individual session's cache. Only when session.Commit is called, all the
|
||||
// changes to pgSQL cache will be applied.
|
||||
type pgSessionCache struct {
|
||||
c *lru.ARCCache
|
||||
}
|
||||
|
||||
type pgSQL struct {
|
||||
*sql.DB
|
||||
|
||||
@ -88,12 +45,6 @@ type pgSQL struct {
|
||||
config Config
|
||||
}
|
||||
|
||||
type pgSession struct {
|
||||
*sql.Tx
|
||||
|
||||
key pagination.Key
|
||||
}
|
||||
|
||||
// Begin initiates a transaction to database.
|
||||
//
|
||||
// The expected transaction isolation level in this implementation is "Read
|
||||
@ -109,10 +60,6 @@ func (pgSQL *pgSQL) Begin() (database.Session, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) Commit() error {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
|
||||
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
|
||||
// the configuration.
|
||||
func (pgSQL *pgSQL) Close() {
|
||||
@ -131,15 +78,6 @@ func (pgSQL *pgSQL) Ping() bool {
|
||||
return pgSQL.DB.Ping() == nil
|
||||
}
|
||||
|
||||
// Page is the representation of a page for the Postgres schema.
|
||||
type Page struct {
|
||||
// StartID is the ID being used as the basis for pagination across database
|
||||
// results. It is used to search for an ancestry with ID >= StartID.
|
||||
//
|
||||
// StartID is required to be unique to every ancestry and always increasing.
|
||||
StartID int64
|
||||
}
|
||||
|
||||
// Config is the configuration that is used by openDatabase.
|
||||
type Config struct {
|
||||
Source string
|
||||
@ -313,42 +251,3 @@ func dropDatabase(source, dbName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||
// The function ensures we never return plain SQL errors and leak anything.
|
||||
// The function should be used for every database query error.
|
||||
func handleError(desc string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return commonerr.ErrNotFound
|
||||
}
|
||||
|
||||
log.WithError(err).WithField("Description", desc).Error("database: handled database error")
|
||||
promErrorsTotal.WithLabelValues(desc).Inc()
|
||||
|
||||
if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
|
||||
return database.ErrBackendException
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// isErrUniqueViolation determines is the given error is a unique contraint violation.
|
||||
func isErrUniqueViolation(err error) bool {
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
return ok && pqErr.Code == "23505"
|
||||
}
|
||||
|
||||
// observeQueryTime computes the time elapsed since `start` to represent the
|
||||
// query time.
|
||||
// 1. `query` is a pgSession function name.
|
||||
// 2. `subquery` is a specific query or a batched query.
|
||||
// 3. `start` is the time right before query is executed.
|
||||
func observeQueryTime(query, subquery string, start time.Time) {
|
||||
promQueryDurationMilliseconds.
|
||||
WithLabelValues(query, subquery).
|
||||
Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond))
|
||||
}
|
||||
|
@ -1,272 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pborman/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
withFixtureName, withoutFixtureName string
|
||||
)
|
||||
|
||||
var testPaginationKey = pagination.Must(pagination.NewKey())
|
||||
|
||||
func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) {
|
||||
config := generateTestConfig(name, loadFixture, false)
|
||||
source := config.Options["source"].(string)
|
||||
name, url, err := parseConnectionString(source)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fixturePath := config.Options["fixturepath"].(string)
|
||||
|
||||
if err := createDatabase(url, name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// migration and fixture
|
||||
db, err := sql.Open("postgres", source)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Verify database state.
|
||||
if err := db.Ping(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run migrations.
|
||||
if err := migrateDatabase(db); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if loadFixture {
|
||||
d, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(d))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name)
|
||||
db.Close()
|
||||
|
||||
log.Info("Generated Template database ", name)
|
||||
return url, name
|
||||
}
|
||||
|
||||
func dropTemplateDatabase(url string, name string) {
|
||||
db, err := sql.Open("postgres", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := dropDatabase(url, name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
fURL, fName := genTemplateDatabase("fixture", true)
|
||||
nfURL, nfName := genTemplateDatabase("nonfixture", false)
|
||||
|
||||
withFixtureName = fName
|
||||
withoutFixtureName = nfName
|
||||
|
||||
rc := m.Run()
|
||||
|
||||
dropTemplateDatabase(fURL, fName)
|
||||
dropTemplateDatabase(nfURL, nfName)
|
||||
os.Exit(rc)
|
||||
}
|
||||
|
||||
func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) {
|
||||
var fixtureName string
|
||||
if fixture {
|
||||
fixtureName = withFixtureName
|
||||
} else {
|
||||
fixtureName = withoutFixtureName
|
||||
}
|
||||
|
||||
// copy the database into new database
|
||||
var pg pgSQL
|
||||
// Parse configuration.
|
||||
pg.config = Config{
|
||||
CacheSize: 16384,
|
||||
}
|
||||
|
||||
bytes, err := yaml.Marshal(testConfig.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
err = yaml.Unmarshal(bytes, &pg.config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
|
||||
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create database.
|
||||
if pg.config.ManageDatabaseLifecycle {
|
||||
if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Open database.
|
||||
pg.DB, err = sql.Open("postgres", pg.config.Source)
|
||||
fmt.Println("database", pg.config.Source)
|
||||
if err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
|
||||
}
|
||||
|
||||
return &pg, nil
|
||||
}
|
||||
|
||||
// copyDatabase creates a new database with
|
||||
func copyDatabase(url, name string, templateName string) error {
|
||||
// Open database.
|
||||
db, err := sql.Open("postgres", url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create database with copy
|
||||
_, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgsql: could not create database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
|
||||
var (
|
||||
db database.Datastore
|
||||
err error
|
||||
testConfig = generateTestConfig(testName, loadFixture, true)
|
||||
)
|
||||
|
||||
db, err = openCopiedDatabase(testConfig, loadFixture)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datastore := db.(*pgSQL)
|
||||
return datastore, nil
|
||||
}
|
||||
|
||||
func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig {
|
||||
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
||||
|
||||
var fixturePath string
|
||||
if loadFixture {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
fixturePath = filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
|
||||
}
|
||||
|
||||
source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName)
|
||||
if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" {
|
||||
source = fmt.Sprintf(sourceEnv, dbName)
|
||||
}
|
||||
|
||||
return database.RegistrableComponentConfig{
|
||||
Options: map[string]interface{}{
|
||||
"source": source,
|
||||
"cachesize": 0,
|
||||
"managedatabaselifecycle": manageLife,
|
||||
"fixturepath": fixturePath,
|
||||
"paginationkey": testPaginationKey.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func closeTest(t *testing.T, store database.Datastore, session database.Session) {
|
||||
err := session.Rollback()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
store.Close()
|
||||
}
|
||||
|
||||
func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) {
|
||||
store, err := openDatabaseForTest(name, loadFixture)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
tx, err := store.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
return store, tx.(*pgSession)
|
||||
}
|
||||
|
||||
func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession {
|
||||
var err error
|
||||
if !commit {
|
||||
err = tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
if assert.Nil(t, err) {
|
||||
session, err := datastore.Begin()
|
||||
if assert.Nil(t, err) {
|
||||
return session.(*pgSession)
|
||||
}
|
||||
}
|
||||
t.FailNow()
|
||||
return nil
|
||||
}
|
@ -1,187 +0,0 @@
|
||||
// Copyright 2015 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// NOTE(Sida): Every search query can only have count less than postgres set
|
||||
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
|
||||
// stack depth. TODO(Sida): Generate different queries for different count: if
|
||||
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
|
||||
// count > 65535, use is expected to split data into batches.
|
||||
func querySearchLastDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT vid, vname, nname FROM (
|
||||
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
|
||||
row_number() OVER (
|
||||
PARTITION by (v.name, n.name)
|
||||
ORDER BY v.deleted_at DESC
|
||||
) AS rownum
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND (v.name, n.name) IN ( %s )
|
||||
AND v.deleted_at IS NOT NULL
|
||||
) tmp WHERE rownum <= 1`,
|
||||
queryString(2, count))
|
||||
}
|
||||
|
||||
func querySearchNotDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
|
||||
AND v.deleted_at IS NULL`,
|
||||
queryString(2, count))
|
||||
}
|
||||
|
||||
func querySearchFeatureID(featureCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT id, name, version, version_format, type
|
||||
FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
|
||||
queryString(4, featureCount),
|
||||
)
|
||||
}
|
||||
|
||||
func querySearchNamespacedFeature(nsfCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
|
||||
FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
|
||||
WHERE nf.feature_id = f.id
|
||||
AND nf.namespace_id = n.id
|
||||
AND n.version_format = f.version_format
|
||||
AND f.type = t.id
|
||||
AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
|
||||
queryString(5, nsfCount),
|
||||
)
|
||||
}
|
||||
|
||||
func querySearchNamespace(nsCount int) string {
|
||||
return fmt.Sprintf(
|
||||
`SELECT id, name, version_format
|
||||
FROM namespace WHERE (name, version_format) IN (%s)`,
|
||||
queryString(2, nsCount),
|
||||
)
|
||||
}
|
||||
|
||||
func queryInsert(count int, table string, columns ...string) string {
|
||||
base := `INSERT INTO %s (%s) VALUES %s`
|
||||
t := pq.QuoteIdentifier(table)
|
||||
cols := make([]string, len(columns))
|
||||
for i, c := range columns {
|
||||
cols[i] = pq.QuoteIdentifier(c)
|
||||
}
|
||||
colsQuoted := strings.Join(cols, ",")
|
||||
return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count))
|
||||
}
|
||||
|
||||
func queryPersist(count int, table, constraint string, columns ...string) string {
|
||||
ct := ""
|
||||
if constraint != "" {
|
||||
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
|
||||
}
|
||||
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct)
|
||||
}
|
||||
|
||||
func queryInsertNotifications(count int) string {
|
||||
return queryInsert(count,
|
||||
"vulnerability_notification",
|
||||
"name",
|
||||
"created_at",
|
||||
"old_vulnerability_id",
|
||||
"new_vulnerability_id",
|
||||
)
|
||||
}
|
||||
|
||||
func queryPersistFeature(count int) string {
|
||||
return queryPersist(count,
|
||||
"feature",
|
||||
"feature_name_version_version_format_type_key",
|
||||
"name",
|
||||
"version",
|
||||
"version_format",
|
||||
"type")
|
||||
}
|
||||
|
||||
func queryPersistLayerFeature(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_feature",
|
||||
"layer_feature_layer_id_feature_id_namespace_id_key",
|
||||
"layer_id",
|
||||
"feature_id",
|
||||
"detector_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
func queryPersistNamespace(count int) string {
|
||||
return queryPersist(count,
|
||||
"namespace",
|
||||
"namespace_name_version_format_key",
|
||||
"name",
|
||||
"version_format")
|
||||
}
|
||||
|
||||
func queryPersistLayerNamespace(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_namespace",
|
||||
"layer_namespace_layer_id_namespace_id_key",
|
||||
"layer_id",
|
||||
"namespace_id",
|
||||
"detector_id")
|
||||
}
|
||||
|
||||
// size of key and array should be both greater than 0
|
||||
func queryString(keySize, arraySize int) string {
|
||||
if arraySize <= 0 || keySize <= 0 {
|
||||
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
|
||||
}
|
||||
keys := make([]string, 0, arraySize)
|
||||
for i := 0; i < arraySize; i++ {
|
||||
key := make([]string, keySize)
|
||||
for j := 0; j < keySize; j++ {
|
||||
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
|
||||
}
|
||||
return strings.Join(keys, ",")
|
||||
}
|
||||
|
||||
func queryPersistNamespacedFeature(count int) string {
|
||||
return queryPersist(count, "namespaced_feature",
|
||||
"namespaced_feature_namespace_id_feature_id_key",
|
||||
"feature_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
|
||||
return queryPersist(count, "vulnerability_affected_namespaced_feature",
|
||||
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
|
||||
"vulnerability_id",
|
||||
"namespaced_feature_id",
|
||||
"added_by")
|
||||
}
|
||||
|
||||
func queryPersistLayer(count int) string {
|
||||
return queryPersist(count, "layer", "", "hash")
|
||||
}
|
||||
|
||||
func queryInvalidateVulnerabilityCache(count int) string {
|
||||
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
|
||||
WHERE vulnerability_id IN (%s)`,
|
||||
queryString(1, count))
|
||||
}
|
@ -1,424 +0,0 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/remind101/migrate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
// int keys must be the consistent with the database ID.
|
||||
var (
|
||||
realFeatures = map[int]database.Feature{
|
||||
1: {"ourchat", "0.5", "dpkg", "source"},
|
||||
2: {"openssl", "1.0", "dpkg", "source"},
|
||||
3: {"openssl", "2.0", "dpkg", "source"},
|
||||
4: {"fake", "2.0", "rpm", "source"},
|
||||
5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"},
|
||||
}
|
||||
|
||||
realNamespaces = map[int]database.Namespace{
|
||||
1: {"debian:7", "dpkg"},
|
||||
2: {"debian:8", "dpkg"},
|
||||
3: {"fake:1.0", "rpm"},
|
||||
4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"},
|
||||
}
|
||||
|
||||
realNamespacedFeatures = map[int]database.NamespacedFeature{
|
||||
1: {realFeatures[1], realNamespaces[1]},
|
||||
2: {realFeatures[2], realNamespaces[1]},
|
||||
3: {realFeatures[2], realNamespaces[2]},
|
||||
4: {realFeatures[3], realNamespaces[1]},
|
||||
}
|
||||
|
||||
realDetectors = map[int]database.Detector{
|
||||
1: database.NewNamespaceDetector("os-release", "1.0"),
|
||||
2: database.NewFeatureDetector("dpkg", "1.0"),
|
||||
3: database.NewFeatureDetector("rpm", "1.0"),
|
||||
4: database.NewNamespaceDetector("apt-sources", "1.0"),
|
||||
}
|
||||
|
||||
realLayers = map[int]database.Layer{
|
||||
2: {
|
||||
Hash: "layer-1",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2]},
|
||||
Features: []database.LayerFeature{
|
||||
{realFeatures[1], realDetectors[2], database.Namespace{}},
|
||||
{realFeatures[2], realDetectors[2], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[1], realDetectors[1]},
|
||||
},
|
||||
},
|
||||
6: {
|
||||
Hash: "layer-4",
|
||||
By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
|
||||
Features: []database.LayerFeature{
|
||||
{realFeatures[4], realDetectors[3], database.Namespace{}},
|
||||
{realFeatures[3], realDetectors[2], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{realNamespaces[1], realDetectors[1]},
|
||||
{realNamespaces[3], realDetectors[4]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
realAncestries = map[int]database.Ancestry{
|
||||
2: {
|
||||
Name: "ancestry-2",
|
||||
By: []database.Detector{realDetectors[2], realDetectors[1]},
|
||||
Layers: []database.AncestryLayer{
|
||||
{
|
||||
"layer-0",
|
||||
[]database.AncestryFeature{},
|
||||
},
|
||||
{
|
||||
"layer-1",
|
||||
[]database.AncestryFeature{},
|
||||
},
|
||||
{
|
||||
"layer-2",
|
||||
[]database.AncestryFeature{
|
||||
{
|
||||
realNamespacedFeatures[1],
|
||||
realDetectors[2],
|
||||
realDetectors[1],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"layer-3b",
|
||||
[]database.AncestryFeature{
|
||||
{
|
||||
realNamespacedFeatures[3],
|
||||
realDetectors[2],
|
||||
realDetectors[1],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
realVulnerability = map[int]database.Vulnerability{
|
||||
1: {
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Namespace: realNamespaces[1],
|
||||
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
|
||||
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
2: {
|
||||
Name: "CVE-NOPE",
|
||||
Namespace: realNamespaces[1],
|
||||
Description: "A vulnerability affecting nothing",
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
}
|
||||
|
||||
realNotification = map[int]database.VulnerabilityNotification{
|
||||
1: {
|
||||
NotificationHook: database.NotificationHook{
|
||||
Name: "test",
|
||||
},
|
||||
Old: takeVulnerabilityPointerFromMap(realVulnerability, 2),
|
||||
New: takeVulnerabilityPointerFromMap(realVulnerability, 1),
|
||||
},
|
||||
}
|
||||
|
||||
fakeFeatures = map[int]database.Feature{
|
||||
1: {
|
||||
Name: "ourchat",
|
||||
Version: "0.6",
|
||||
VersionFormat: "dpkg",
|
||||
Type: "source",
|
||||
},
|
||||
}
|
||||
|
||||
fakeNamespaces = map[int]database.Namespace{
|
||||
1: {"green hat", "rpm"},
|
||||
}
|
||||
fakeNamespacedFeatures = map[int]database.NamespacedFeature{
|
||||
1: {
|
||||
Feature: fakeFeatures[0],
|
||||
Namespace: realNamespaces[0],
|
||||
},
|
||||
}
|
||||
|
||||
fakeDetector = map[int]database.Detector{
|
||||
1: {
|
||||
Name: "fake",
|
||||
Version: "1.0",
|
||||
DType: database.FeatureDetectorType,
|
||||
},
|
||||
2: {
|
||||
Name: "fake2",
|
||||
Version: "2.0",
|
||||
DType: database.NamespaceDetectorType,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
|
||||
rows, err := tx.Query("SELECT name, version_format FROM namespace")
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
namespaces := []database.Namespace{}
|
||||
for rows.Next() {
|
||||
var ns database.Namespace
|
||||
err := rows.Scan(&ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
||||
func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
|
||||
if expected == actual {
|
||||
return true
|
||||
}
|
||||
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
|
||||
AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
|
||||
AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
|
||||
}
|
||||
|
||||
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
|
||||
if expected == actual {
|
||||
return true
|
||||
}
|
||||
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
|
||||
assert.Equal(t, expected.Limit, actual.Limit) &&
|
||||
assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
|
||||
assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
|
||||
assert.Equal(t, expected.End, actual.End) &&
|
||||
database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page {
|
||||
if token == pagination.FirstPageToken {
|
||||
return Page{}
|
||||
}
|
||||
|
||||
p := Page{}
|
||||
if err := key.UnmarshalToken(token, &p); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
|
||||
token, err := key.MarshalToken(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
|
||||
|
||||
func createAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
|
||||
uri := "postgres@127.0.0.1:5432"
|
||||
connectionTemplate := "postgresql://%s?sslmode=disable"
|
||||
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
|
||||
uri = envURI
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testName = strings.ToLower(testName)
|
||||
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
|
||||
t.Logf("creating temporary database name = %s", dbName)
|
||||
_, err = db.Exec("CREATE DATABASE " + dbName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return testDB, func() {
|
||||
t.Logf("cleaning up temporary database %s", dbName)
|
||||
defer db.Close()
|
||||
if err := testDB.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`DROP DATABASE ` + dbName); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// ensure the database is cleaned up
|
||||
var count int
|
||||
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPgSQL(t *testing.T, testName string) (*pgSQL, func()) {
|
||||
connection, cleanup := createAndConnectTestDB(t, testName)
|
||||
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
|
||||
if err != nil {
|
||||
require.Nil(t, err, err.Error())
|
||||
}
|
||||
|
||||
return &pgSQL{connection, nil, Config{PaginationKey: pagination.Must(pagination.NewKey()).String()}}, cleanup
|
||||
}
|
||||
|
||||
func createTestPgSQLWithFixtures(t *testing.T, testName string) (*pgSQL, func()) {
|
||||
connection, cleanup := createTestPgSQL(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer session.Rollback()
|
||||
|
||||
loadFixtures(session.(*pgSession))
|
||||
if err = session.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return connection, cleanup
|
||||
}
|
||||
|
||||
func createTestPgSession(t *testing.T, testName string) (*pgSession, func()) {
|
||||
connection, cleanup := createTestPgSQL(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return session.(*pgSession), func() {
|
||||
session.Rollback()
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPgSessionWithFixtures(t *testing.T, testName string) (*pgSession, func()) {
|
||||
tx, cleanup := createTestPgSession(t, testName)
|
||||
defer func() {
|
||||
// ensure to cleanup when loadFixtures failed
|
||||
if r := recover(); r != nil {
|
||||
cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
loadFixtures(tx)
|
||||
return tx, cleanup
|
||||
}
|
||||
|
||||
func loadFixtures(tx *pgSession) {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
fixturePath := filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
|
||||
d, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(string(d))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
|
||||
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.AffectedFeature]bool{}
|
||||
for _, i := range expected {
|
||||
has[i] = false
|
||||
}
|
||||
for _, i := range actual {
|
||||
if visited, ok := has[i]; !ok {
|
||||
return false
|
||||
} else if visited {
|
||||
return false
|
||||
}
|
||||
has[i] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
|
||||
r := make([]database.Namespace, count)
|
||||
for i := 0; i < count; i++ {
|
||||
r[i] = database.Namespace{
|
||||
Name: fmt.Sprint(rand.Int()),
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
98
database/pgsql/testutil/assertion.go
Normal file
98
database/pgsql/testutil/assertion.go
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func AssertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
|
||||
if expected == actual {
|
||||
return true
|
||||
}
|
||||
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
|
||||
AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
|
||||
AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
|
||||
}
|
||||
|
||||
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
|
||||
if expected == actual {
|
||||
return true
|
||||
}
|
||||
|
||||
if expected == nil || actual == nil {
|
||||
return assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
|
||||
assert.Equal(t, expected.Limit, actual.Limit) &&
|
||||
assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
|
||||
assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
|
||||
assert.Equal(t, expected.End, actual.End) &&
|
||||
database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func AssertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
|
||||
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && AssertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func AssertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.AffectedFeature]bool{}
|
||||
for _, i := range expected {
|
||||
has[i] = false
|
||||
}
|
||||
for _, i := range actual {
|
||||
if visited, ok := has[i]; !ok {
|
||||
return false
|
||||
} else if visited {
|
||||
return false
|
||||
}
|
||||
has[i] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AssertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.NamespacedFeature]bool{}
|
||||
for _, nf := range expected {
|
||||
has[nf] = false
|
||||
}
|
||||
|
||||
for _, nf := range actual {
|
||||
has[nf] = true
|
||||
}
|
||||
|
||||
for nf, visited := range has {
|
||||
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
171
database/pgsql/testutil/data.go
Normal file
171
database/pgsql/testutil/data.go
Normal file
@ -0,0 +1,171 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import "github.com/coreos/clair/database"
|
||||
|
||||
// int keys must be the consistent with the database ID.
|
||||
var (
|
||||
RealFeatures = map[int]database.Feature{
|
||||
1: {"ourchat", "0.5", "dpkg", "source"},
|
||||
2: {"openssl", "1.0", "dpkg", "source"},
|
||||
3: {"openssl", "2.0", "dpkg", "source"},
|
||||
4: {"fake", "2.0", "rpm", "source"},
|
||||
5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"},
|
||||
}
|
||||
|
||||
RealNamespaces = map[int]database.Namespace{
|
||||
1: {"debian:7", "dpkg"},
|
||||
2: {"debian:8", "dpkg"},
|
||||
3: {"fake:1.0", "rpm"},
|
||||
4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"},
|
||||
}
|
||||
|
||||
RealNamespacedFeatures = map[int]database.NamespacedFeature{
|
||||
1: {RealFeatures[1], RealNamespaces[1]},
|
||||
2: {RealFeatures[2], RealNamespaces[1]},
|
||||
3: {RealFeatures[2], RealNamespaces[2]},
|
||||
4: {RealFeatures[3], RealNamespaces[1]},
|
||||
}
|
||||
|
||||
RealDetectors = map[int]database.Detector{
|
||||
1: database.NewNamespaceDetector("os-release", "1.0"),
|
||||
2: database.NewFeatureDetector("dpkg", "1.0"),
|
||||
3: database.NewFeatureDetector("rpm", "1.0"),
|
||||
4: database.NewNamespaceDetector("apt-sources", "1.0"),
|
||||
}
|
||||
|
||||
RealLayers = map[int]database.Layer{
|
||||
2: {
|
||||
Hash: "layer-1",
|
||||
By: []database.Detector{RealDetectors[1], RealDetectors[2]},
|
||||
Features: []database.LayerFeature{
|
||||
{RealFeatures[1], RealDetectors[2], database.Namespace{}},
|
||||
{RealFeatures[2], RealDetectors[2], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{RealNamespaces[1], RealDetectors[1]},
|
||||
},
|
||||
},
|
||||
6: {
|
||||
Hash: "layer-4",
|
||||
By: []database.Detector{RealDetectors[1], RealDetectors[2], RealDetectors[3], RealDetectors[4]},
|
||||
Features: []database.LayerFeature{
|
||||
{RealFeatures[4], RealDetectors[3], database.Namespace{}},
|
||||
{RealFeatures[3], RealDetectors[2], database.Namespace{}},
|
||||
},
|
||||
Namespaces: []database.LayerNamespace{
|
||||
{RealNamespaces[1], RealDetectors[1]},
|
||||
{RealNamespaces[3], RealDetectors[4]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RealAncestries = map[int]database.Ancestry{
|
||||
2: {
|
||||
Name: "ancestry-2",
|
||||
By: []database.Detector{RealDetectors[2], RealDetectors[1]},
|
||||
Layers: []database.AncestryLayer{
|
||||
{
|
||||
"layer-0",
|
||||
[]database.AncestryFeature{},
|
||||
},
|
||||
{
|
||||
"layer-1",
|
||||
[]database.AncestryFeature{},
|
||||
},
|
||||
{
|
||||
"layer-2",
|
||||
[]database.AncestryFeature{
|
||||
{
|
||||
RealNamespacedFeatures[1],
|
||||
RealDetectors[2],
|
||||
RealDetectors[1],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"layer-3b",
|
||||
[]database.AncestryFeature{
|
||||
{
|
||||
RealNamespacedFeatures[3],
|
||||
RealDetectors[2],
|
||||
RealDetectors[1],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RealVulnerability = map[int]database.Vulnerability{
|
||||
1: {
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Namespace: RealNamespaces[1],
|
||||
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
|
||||
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
2: {
|
||||
Name: "CVE-NOPE",
|
||||
Namespace: RealNamespaces[1],
|
||||
Description: "A vulnerability affecting nothing",
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
}
|
||||
|
||||
RealNotification = map[int]database.VulnerabilityNotification{
|
||||
1: {
|
||||
NotificationHook: database.NotificationHook{
|
||||
Name: "test",
|
||||
},
|
||||
Old: takeVulnerabilityPointerFromMap(RealVulnerability, 2),
|
||||
New: takeVulnerabilityPointerFromMap(RealVulnerability, 1),
|
||||
},
|
||||
}
|
||||
|
||||
FakeFeatures = map[int]database.Feature{
|
||||
1: {
|
||||
Name: "ourchat",
|
||||
Version: "0.6",
|
||||
VersionFormat: "dpkg",
|
||||
Type: "source",
|
||||
},
|
||||
}
|
||||
|
||||
FakeNamespaces = map[int]database.Namespace{
|
||||
1: {"green hat", "rpm"},
|
||||
}
|
||||
|
||||
FakeNamespacedFeatures = map[int]database.NamespacedFeature{
|
||||
1: {
|
||||
Feature: FakeFeatures[0],
|
||||
Namespace: RealNamespaces[0],
|
||||
},
|
||||
}
|
||||
|
||||
FakeDetector = map[int]database.Detector{
|
||||
1: {
|
||||
Name: "fake",
|
||||
Version: "1.0",
|
||||
DType: database.FeatureDetectorType,
|
||||
},
|
||||
2: {
|
||||
Name: "fake2",
|
||||
Version: "2.0",
|
||||
DType: database.NamespaceDetectorType,
|
||||
},
|
||||
}
|
||||
)
|
207
database/pgsql/testutil/testdb.go
Normal file
207
database/pgsql/testutil/testdb.go
Normal file
@ -0,0 +1,207 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
"github.com/remind101/migrate"
|
||||
)
|
||||
|
||||
var TestPaginationKey = pagination.Must(pagination.NewKey())
|
||||
|
||||
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
|
||||
|
||||
func CreateAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
|
||||
uri := "postgres@127.0.0.1:5432"
|
||||
connectionTemplate := "postgresql://%s?sslmode=disable"
|
||||
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
|
||||
uri = envURI
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testName = strings.ToLower(testName)
|
||||
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
|
||||
t.Logf("creating temporary database name = %s", dbName)
|
||||
_, err = db.Exec("CREATE DATABASE " + dbName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return testDB, func() {
|
||||
cleanupTestDB(t, dbName, db, testDB)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTestDB(t *testing.T, name string, db, testDB *sql.DB) {
|
||||
t.Logf("cleaning up temporary database %s", name)
|
||||
if db == nil {
|
||||
panic("db is none")
|
||||
}
|
||||
|
||||
if testDB == nil {
|
||||
panic("testDB is none")
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
if err := testDB.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Kill any opened connection.
|
||||
if _, err := db.Exec(`
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = $1
|
||||
AND pid <> pg_backend_pid()`, name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`DROP DATABASE ` + name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// ensure the database is cleaned up
|
||||
var count int
|
||||
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestDB(t *testing.T, testName string) (*sql.DB, func()) {
|
||||
connection, cleanup := CreateAndConnectTestDB(t, testName)
|
||||
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return connection, cleanup
|
||||
}
|
||||
|
||||
func CreateTestDBWithFixture(t *testing.T, testName string) (*sql.DB, func()) {
|
||||
connection, cleanup := CreateTestDB(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer session.Rollback()
|
||||
|
||||
loadFixtures(session)
|
||||
if err = session.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return connection, cleanup
|
||||
}
|
||||
|
||||
func CreateTestTx(t *testing.T, testName string) (*sql.Tx, func()) {
|
||||
connection, cleanup := CreateTestDB(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if session == nil {
|
||||
panic("session is none")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return session, func() {
|
||||
session.Rollback()
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestTxWithFixtures(t *testing.T, testName string) (*sql.Tx, func()) {
|
||||
tx, cleanup := CreateTestTx(t, testName)
|
||||
defer func() {
|
||||
// ensure to cleanup when loadFixtures failed
|
||||
if r := recover(); r != nil {
|
||||
cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
loadFixtures(tx)
|
||||
return tx, cleanup
|
||||
}
|
||||
|
||||
func loadFixtures(tx *sql.Tx) {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
fixturePath := filepath.Join(filepath.Dir(filename), "data.sql")
|
||||
d, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(string(d))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func OpenSessionForTest(t *testing.T, name string, loadFixture bool) (*sql.DB, *sql.Tx) {
|
||||
var db *sql.DB
|
||||
if loadFixture {
|
||||
db, _ = CreateTestDB(t, name)
|
||||
} else {
|
||||
db, _ = CreateTestDBWithFixture(t, name)
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return db, tx
|
||||
}
|
||||
|
||||
func RestartTransaction(db *sql.DB, tx *sql.Tx, commit bool) *sql.Tx {
|
||||
if !commit {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
if err := tx.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
94
database/pgsql/testutil/testutil.go
Normal file
94
database/pgsql/testutil/testutil.go
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/page"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func TakeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func TakeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func ListNamespaces(t *testing.T, tx *sql.Tx) []database.Namespace {
|
||||
rows, err := tx.Query("SELECT name, version_format FROM namespace")
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
namespaces := []database.Namespace{}
|
||||
for rows.Next() {
|
||||
var ns database.Namespace
|
||||
err := rows.Scan(&ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
||||
func mustUnmarshalToken(key pagination.Key, token pagination.Token) page.Page {
|
||||
if token == pagination.FirstPageToken {
|
||||
return page.Page{}
|
||||
}
|
||||
|
||||
p := page.Page{}
|
||||
if err := key.UnmarshalToken(token, &p); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func MustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
|
||||
token, err := key.MarshalToken(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func GenRandomNamespaces(t *testing.T, count int) []database.Namespace {
|
||||
r := make([]database.Namespace, count)
|
||||
for i := 0; i < count; i++ {
|
||||
r[i] = database.Namespace{
|
||||
Name: fmt.Sprint(rand.Int()),
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
63
database/pgsql/util/error.go
Normal file
63
database/pgsql/util/error.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/monitoring"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
"github.com/lib/pq"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// IsErrUniqueViolation determines is the given error is a unique contraint violation.
|
||||
func IsErrUniqueViolation(err error) bool {
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
return ok && pqErr.Code == "23505"
|
||||
}
|
||||
|
||||
// HandleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||
// The function ensures we never return plain SQL errors and leak anything.
|
||||
// The function should be used for every database query error.
|
||||
func HandleError(desc string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return commonerr.ErrNotFound
|
||||
}
|
||||
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
if pqErr.Fatal() {
|
||||
panic(pqErr)
|
||||
}
|
||||
|
||||
if pqErr.Code == "42601" {
|
||||
panic("invalid query: " + desc + ", info: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
logrus.WithError(err).WithField("Description", desc).Error("database: handled database error")
|
||||
monitoring.PromErrorsTotal.WithLabelValues(desc).Inc()
|
||||
if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
|
||||
return database.ErrBackendException
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
57
database/pgsql/util/queries.go
Normal file
57
database/pgsql/util/queries.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright 2015 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func QueryInsert(count int, table string, columns ...string) string {
|
||||
base := `INSERT INTO %s (%s) VALUES %s`
|
||||
t := pq.QuoteIdentifier(table)
|
||||
cols := make([]string, len(columns))
|
||||
for i, c := range columns {
|
||||
cols[i] = pq.QuoteIdentifier(c)
|
||||
}
|
||||
colsQuoted := strings.Join(cols, ",")
|
||||
return fmt.Sprintf(base, t, colsQuoted, QueryString(len(columns), count))
|
||||
}
|
||||
|
||||
func QueryPersist(count int, table, constraint string, columns ...string) string {
|
||||
ct := ""
|
||||
if constraint != "" {
|
||||
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
|
||||
}
|
||||
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", QueryInsert(count, table, columns...), ct)
|
||||
}
|
||||
|
||||
// size of key and array should be both greater than 0
|
||||
func QueryString(keySize, arraySize int) string {
|
||||
if arraySize <= 0 || keySize <= 0 {
|
||||
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
|
||||
}
|
||||
keys := make([]string, 0, arraySize)
|
||||
for i := 0; i < arraySize; i++ {
|
||||
key := make([]string, keySize)
|
||||
for j := 0; j < keySize; j++ {
|
||||
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
|
||||
}
|
||||
return strings.Join(keys, ",")
|
||||
}
|
@ -12,23 +12,27 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package vulnerability
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/monitoring"
|
||||
"github.com/coreos/clair/database/pgsql/page"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
|
||||
|
||||
searchVulnerability = `
|
||||
SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
@ -38,45 +42,12 @@ const (
|
||||
AND v.deleted_at IS NULL
|
||||
`
|
||||
|
||||
insertVulnerabilityAffected = `
|
||||
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING ID
|
||||
`
|
||||
|
||||
searchVulnerabilityAffected = `
|
||||
SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
|
||||
FROM vulnerability_affected_feature AS vaf, feature_type AS t
|
||||
WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
|
||||
`
|
||||
|
||||
searchVulnerabilityByID = `
|
||||
SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND v.id = $1`
|
||||
|
||||
searchVulnerabilityPotentialAffected = `
|
||||
WITH req AS (
|
||||
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
|
||||
FROM vulnerability_affected_feature AS vaf,
|
||||
vulnerability AS v,
|
||||
namespace AS n
|
||||
WHERE vaf.vulnerability_id = ANY($1)
|
||||
AND v.id = vaf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
)
|
||||
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
|
||||
FROM feature AS f, namespaced_feature AS nf, req
|
||||
WHERE f.name = req.name
|
||||
AND f.type = req.type
|
||||
AND nf.namespace_id = req.n_id
|
||||
AND nf.feature_id = f.id`
|
||||
|
||||
insertVulnerabilityAffectedNamespacedFeature = `
|
||||
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
|
||||
VALUES ($1, $2, $3)`
|
||||
|
||||
insertVulnerability = `
|
||||
WITH ns AS (
|
||||
SELECT id FROM namespace WHERE name = $6 AND version_format = $7
|
||||
@ -92,11 +63,55 @@ const (
|
||||
AND name = $2
|
||||
AND deleted_at IS NULL
|
||||
RETURNING id`
|
||||
|
||||
searchNotificationVulnerableAncestry = `
|
||||
SELECT DISTINCT ON (a.id)
|
||||
a.id, a.name
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
|
||||
WHERE vanf.vulnerability_id = $1
|
||||
AND a.id >= $2
|
||||
AND al.ancestry_id = a.id
|
||||
AND al.id = af.ancestry_layer_id
|
||||
AND af.namespaced_feature_id = vanf.namespaced_feature_id
|
||||
ORDER BY a.id ASC
|
||||
LIMIT $3;`
|
||||
)
|
||||
|
||||
var (
|
||||
errVulnerabilityNotFound = errors.New("vulnerability is not in database")
|
||||
)
|
||||
func queryInvalidateVulnerabilityCache(count int) string {
|
||||
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
|
||||
WHERE vulnerability_id IN (%s)`,
|
||||
util.QueryString(1, count))
|
||||
}
|
||||
|
||||
// NOTE(Sida): Every search query can only have count less than postgres set
|
||||
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
|
||||
// stack depth. TODO(Sida): Generate different queries for different count: if
|
||||
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
|
||||
// count > 65535, use is expected to split data into batches.
|
||||
func querySearchLastDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT vid, vname, nname FROM (
|
||||
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
|
||||
row_number() OVER (
|
||||
PARTITION by (v.name, n.name)
|
||||
ORDER BY v.deleted_at DESC
|
||||
) AS rownum
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND (v.name, n.name) IN ( %s )
|
||||
AND v.deleted_at IS NOT NULL
|
||||
) tmp WHERE rownum <= 1`,
|
||||
util.QueryString(2, count))
|
||||
}
|
||||
|
||||
func querySearchNotDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
|
||||
AND v.deleted_at IS NULL`,
|
||||
util.QueryString(2, count))
|
||||
}
|
||||
|
||||
type affectedAncestry struct {
|
||||
name string
|
||||
@ -113,8 +128,8 @@ type affectedFeatureRows struct {
|
||||
rows map[int64]database.AffectedFeature
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
||||
defer observeQueryTime("findVulnerabilities", "", time.Now())
|
||||
func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
||||
defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now())
|
||||
resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
|
||||
vulnIDMap := map[int64][]*database.NullableVulnerability{}
|
||||
|
||||
@ -151,7 +166,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
stmt.Close()
|
||||
return nil, handleError("searchVulnerability", err)
|
||||
return nil, util.HandleError("searchVulnerability", err)
|
||||
}
|
||||
vuln.Valid = id.Valid
|
||||
resultVuln[i] = vuln
|
||||
@ -161,7 +176,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return nil, handleError("searchVulnerability", err)
|
||||
return nil, util.HandleError("searchVulnerability", err)
|
||||
}
|
||||
|
||||
toQuery := make([]int64, 0, len(vulnIDMap))
|
||||
@ -172,7 +187,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
|
||||
// load vulnerability affected features
|
||||
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
|
||||
if err != nil {
|
||||
return nil, handleError("searchVulnerabilityAffected", err)
|
||||
return nil, util.HandleError("searchVulnerabilityAffected", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
@ -183,7 +198,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
|
||||
|
||||
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion)
|
||||
if err != nil {
|
||||
return nil, handleError("searchVulnerabilityAffected", err)
|
||||
return nil, util.HandleError("searchVulnerabilityAffected", err)
|
||||
}
|
||||
|
||||
for _, vuln := range vulnIDMap[id] {
|
||||
@ -195,41 +210,40 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
|
||||
return resultVuln, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error {
|
||||
defer observeQueryTime("insertVulnerabilities", "all", time.Now())
|
||||
func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error {
|
||||
defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now())
|
||||
// bulk insert vulnerabilities
|
||||
vulnIDs, err := tx.insertVulnerabilities(vulnerabilities)
|
||||
vulnIDs, err := insertVulnerabilities(tx, vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// bulk insert vulnerability affected features
|
||||
vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities)
|
||||
vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap)
|
||||
return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap)
|
||||
}
|
||||
|
||||
// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
|
||||
//
|
||||
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
|
||||
func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
|
||||
func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
|
||||
var (
|
||||
vulnFeature = map[int64]affectedFeatureRows{}
|
||||
affectedID int64
|
||||
)
|
||||
|
||||
types, err := tx.getFeatureTypeMap()
|
||||
types, err := feature.GetFeatureTypeMap(tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//TODO(Sida): Change to bulk insert.
|
||||
stmt, err := tx.Prepare(insertVulnerabilityAffected)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerabilityAffected", err)
|
||||
return nil, util.HandleError("insertVulnerabilityAffected", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
@ -237,9 +251,9 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
|
||||
// affected feature row ID -> affected feature
|
||||
affectedFeatures := map[int64]database.AffectedFeature{}
|
||||
for _, f := range vuln.Affected {
|
||||
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
|
||||
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerabilityAffected", err)
|
||||
return nil, util.HandleError("insertVulnerabilityAffected", err)
|
||||
}
|
||||
affectedFeatures[affectedID] = f
|
||||
}
|
||||
@ -251,7 +265,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
|
||||
|
||||
// insertVulnerabilities inserts a set of unique vulnerabilities into database,
|
||||
// under the assumption that all vulnerabilities are valid.
|
||||
func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
|
||||
func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
|
||||
var (
|
||||
vulnID int64
|
||||
vulnIDs = make([]int64, 0, len(vulnerabilities))
|
||||
@ -274,7 +288,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
|
||||
//TODO(Sida): Change to bulk insert.
|
||||
stmt, err := tx.Prepare(insertVulnerability)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerability", err)
|
||||
return nil, util.HandleError("insertVulnerability", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
@ -283,7 +297,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
|
||||
vuln.Link, &vuln.Severity, &vuln.Metadata,
|
||||
vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerability", err)
|
||||
return nil, util.HandleError("insertVulnerability", err)
|
||||
}
|
||||
|
||||
vulnIDs = append(vulnIDs, vulnID)
|
||||
@ -292,19 +306,19 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil
|
||||
return vulnIDs, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) lockFeatureVulnerabilityCache() error {
|
||||
func LockFeatureVulnerabilityCache(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(lockVulnerabilityAffects)
|
||||
if err != nil {
|
||||
return handleError("lockVulnerabilityAffects", err)
|
||||
return util.HandleError("lockVulnerabilityAffects", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
|
||||
// to affected feature rows and caches them.
|
||||
func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error {
|
||||
func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error {
|
||||
// Prevent InsertNamespacedFeatures to modify it.
|
||||
err := tx.lockFeatureVulnerabilityCache()
|
||||
err := LockFeatureVulnerabilityCache(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -316,7 +330,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
|
||||
|
||||
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
|
||||
if err != nil {
|
||||
return handleError("searchVulnerabilityPotentialAffected", err)
|
||||
return util.HandleError("searchVulnerabilityPotentialAffected", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
@ -332,7 +346,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
|
||||
|
||||
err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
|
||||
if err != nil {
|
||||
return handleError("searchVulnerabilityPotentialAffected", err)
|
||||
return util.HandleError("searchVulnerabilityPotentialAffected", err)
|
||||
}
|
||||
|
||||
candidate, ok := affected[vulnID].rows[addedBy]
|
||||
@ -361,7 +375,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
|
||||
for _, r := range relation {
|
||||
result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
|
||||
if err != nil {
|
||||
return handleError("insertVulnerabilityAffectedNamespacedFeature", err)
|
||||
return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err)
|
||||
}
|
||||
|
||||
if num, err := result.RowsAffected(); err == nil {
|
||||
@ -377,27 +391,27 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error {
|
||||
defer observeQueryTime("DeleteVulnerability", "all", time.Now())
|
||||
func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error {
|
||||
defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now())
|
||||
|
||||
vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities)
|
||||
vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil {
|
||||
if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error {
|
||||
func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error {
|
||||
if len(vulnerabilityIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prevent InsertNamespacedFeatures to modify it.
|
||||
err := tx.lockFeatureVulnerabilityCache()
|
||||
err := LockFeatureVulnerabilityCache(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -410,13 +424,13 @@ func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) erro
|
||||
|
||||
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
|
||||
if err != nil {
|
||||
return handleError("removeVulnerabilityAffectedFeature", err)
|
||||
return util.HandleError("removeVulnerabilityAffectedFeature", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) {
|
||||
func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) {
|
||||
var (
|
||||
vulnID sql.NullInt64
|
||||
vulnIDs []int64
|
||||
@ -425,17 +439,17 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul
|
||||
// mark vulnerabilities deleted
|
||||
stmt, err := tx.Prepare(removeVulnerability)
|
||||
if err != nil {
|
||||
return nil, handleError("removeVulnerability", err)
|
||||
return nil, util.HandleError("removeVulnerability", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
for _, vuln := range vulnerabilities {
|
||||
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
|
||||
if err != nil {
|
||||
return nil, handleError("removeVulnerability", err)
|
||||
return nil, util.HandleError("removeVulnerability", err)
|
||||
}
|
||||
if !vulnID.Valid {
|
||||
return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
|
||||
return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
|
||||
}
|
||||
vulnIDs = append(vulnIDs, vulnID.Int64)
|
||||
}
|
||||
@ -444,15 +458,15 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul
|
||||
|
||||
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
|
||||
// database and the order of output array is not guaranteed.
|
||||
func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return tx.findVulnerabilityIDs(vulnIDs, true)
|
||||
func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return FindVulnerabilityIDs(tx, vulnIDs, true)
|
||||
}
|
||||
|
||||
func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return tx.findVulnerabilityIDs(vulnIDs, false)
|
||||
func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return FindVulnerabilityIDs(tx, vulnIDs, false)
|
||||
}
|
||||
|
||||
func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
|
||||
func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
|
||||
if len(vulnIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@ -474,7 +488,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
|
||||
|
||||
rows, err := tx.Query(query, keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
|
||||
return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
@ -485,7 +499,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
|
||||
return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
|
||||
}
|
||||
vulnIDMap[vulnID] = id
|
||||
}
|
||||
@ -497,3 +511,67 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) {
|
||||
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
||||
currentPage := page.Page{0}
|
||||
if currentToken != pagination.FirstPageToken {
|
||||
if err := key.UnmarshalToken(currentToken, ¤tPage); err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
|
||||
&vulnPage.Name,
|
||||
&vulnPage.Description,
|
||||
&vulnPage.Link,
|
||||
&vulnPage.Severity,
|
||||
&vulnPage.Metadata,
|
||||
&vulnPage.Namespace.Name,
|
||||
&vulnPage.Namespace.VersionFormat,
|
||||
); err != nil {
|
||||
return vulnPage, util.HandleError("searchVulnerabilityByID", err)
|
||||
}
|
||||
|
||||
// the last result is used for the next page's startID
|
||||
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
|
||||
if err != nil {
|
||||
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ancestries := []affectedAncestry{}
|
||||
for rows.Next() {
|
||||
var ancestry affectedAncestry
|
||||
err := rows.Scan(&ancestry.id, &ancestry.name)
|
||||
if err != nil {
|
||||
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
ancestries = append(ancestries, ancestry)
|
||||
}
|
||||
|
||||
lastIndex := 0
|
||||
if len(ancestries)-1 < limit {
|
||||
lastIndex = len(ancestries)
|
||||
vulnPage.End = true
|
||||
} else {
|
||||
// Use the last ancestry's ID as the next page.
|
||||
lastIndex = len(ancestries) - 1
|
||||
vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id})
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
vulnPage.Affected = map[int]string{}
|
||||
for _, ancestry := range ancestries[0:lastIndex] {
|
||||
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
||||
}
|
||||
|
||||
vulnPage.Current, err = key.MarshalToken(currentPage)
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
|
||||
return vulnPage, nil
|
||||
}
|
118
database/pgsql/vulnerability/vulnerability_affected_feature.go
Normal file
118
database/pgsql/vulnerability/vulnerability_affected_feature.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vulnerability
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
searchPotentialAffectingVulneraibilities = `
|
||||
SELECT nf.id, v.id, vaf.affected_version, vaf.id
|
||||
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
|
||||
namespaced_feature AS nf, feature AS f
|
||||
WHERE nf.id = ANY($1)
|
||||
AND nf.feature_id = f.id
|
||||
AND nf.namespace_id = v.namespace_id
|
||||
AND vaf.feature_name = f.name
|
||||
AND vaf.feature_type = f.type
|
||||
AND vaf.vulnerability_id = v.id
|
||||
AND v.deleted_at IS NULL`
|
||||
insertVulnerabilityAffected = `
|
||||
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING ID
|
||||
`
|
||||
searchVulnerabilityAffected = `
|
||||
SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
|
||||
FROM vulnerability_affected_feature AS vaf, feature_type AS t
|
||||
WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
|
||||
`
|
||||
|
||||
searchVulnerabilityPotentialAffected = `
|
||||
WITH req AS (
|
||||
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
|
||||
FROM vulnerability_affected_feature AS vaf,
|
||||
vulnerability AS v,
|
||||
namespace AS n
|
||||
WHERE vaf.vulnerability_id = ANY($1)
|
||||
AND v.id = vaf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
)
|
||||
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
|
||||
FROM feature AS f, namespaced_feature AS nf, req
|
||||
WHERE f.name = req.name
|
||||
AND f.type = req.type
|
||||
AND nf.namespace_id = req.n_id
|
||||
AND nf.feature_id = f.id`
|
||||
)
|
||||
|
||||
type vulnerabilityCache struct {
|
||||
nsFeatureID int64
|
||||
vulnID int64
|
||||
vulnAffectingID int64
|
||||
}
|
||||
|
||||
func SearchAffectingVulnerabilities(tx *sql.Tx, features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ids, err := feature.FindNamespacedFeatureIDs(tx, features)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fMap := map[int64]database.NamespacedFeature{}
|
||||
for i, f := range features {
|
||||
if !ids[i].Valid {
|
||||
return nil, database.ErrMissingEntities
|
||||
}
|
||||
fMap[ids[i].Int64] = f
|
||||
}
|
||||
|
||||
cacheTable := []vulnerabilityCache{}
|
||||
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, util.HandleError("searchPotentialAffectingVulneraibilities", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var (
|
||||
cache vulnerabilityCache
|
||||
affected string
|
||||
)
|
||||
|
||||
err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
cacheTable = append(cacheTable, cache)
|
||||
}
|
||||
}
|
||||
|
||||
return cacheTable, nil
|
||||
}
|
@ -0,0 +1,142 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vulnerability
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/util"
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
searchNamespacedFeaturesVulnerabilities = `
|
||||
SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
|
||||
v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
Vulnerability AS v,
|
||||
vulnerability_affected_feature AS vaf,
|
||||
namespace AS n
|
||||
WHERE vanf.namespaced_feature_id = ANY($1)
|
||||
AND vaf.id = vanf.added_by
|
||||
AND v.id = vanf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
AND v.deleted_at IS NULL`
|
||||
|
||||
lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
|
||||
|
||||
insertVulnerabilityAffectedNamespacedFeature = `
|
||||
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
|
||||
VALUES ($1, $2, $3)`
|
||||
)
|
||||
|
||||
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
|
||||
return util.QueryPersist(count, "vulnerability_affected_namespaced_feature",
|
||||
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
|
||||
"vulnerability_id",
|
||||
"namespaced_feature_id",
|
||||
"added_by")
|
||||
}
|
||||
|
||||
// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
|
||||
// feature.
|
||||
func FindAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
|
||||
featureIDs, err := feature.FindNamespacedFeatureIDs(tx, features)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, id := range featureIDs {
|
||||
if id.Valid {
|
||||
vulnerableFeatures[i].Valid = true
|
||||
vulnerableFeatures[i].NamespacedFeature = features[i]
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
|
||||
if err != nil {
|
||||
return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
featureID int64
|
||||
vuln database.VulnerabilityWithFixedIn
|
||||
)
|
||||
|
||||
err := rows.Scan(&featureID,
|
||||
&vuln.Name,
|
||||
&vuln.Description,
|
||||
&vuln.Link,
|
||||
&vuln.Severity,
|
||||
&vuln.Metadata,
|
||||
&vuln.FixedInVersion,
|
||||
&vuln.Namespace.Name,
|
||||
&vuln.Namespace.VersionFormat,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
|
||||
for i, id := range featureIDs {
|
||||
if id.Valid && id.Int64 == featureID {
|
||||
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vulnerableFeatures, nil
|
||||
}
|
||||
|
||||
func CacheAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := tx.Exec(lockVulnerabilityAffects)
|
||||
if err != nil {
|
||||
return util.HandleError("lockVulnerabilityAffects", err)
|
||||
}
|
||||
|
||||
cache, err := SearchAffectingVulnerabilities(tx, features)
|
||||
|
||||
keys := make([]interface{}, 0, len(cache)*3)
|
||||
for _, c := range cache {
|
||||
keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID)
|
||||
}
|
||||
|
||||
if len(cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
|
||||
if err != nil {
|
||||
return util.HandleError("persistVulnerabilityAffectedNamespacedFeature", err)
|
||||
}
|
||||
if count, err := affected.RowsAffected(); err != nil {
|
||||
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -12,19 +12,30 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
package vulnerability
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/feature"
|
||||
"github.com/coreos/clair/database/pgsql/namespace"
|
||||
"github.com/coreos/clair/database/pgsql/testutil"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/strutil"
|
||||
)
|
||||
|
||||
func TestInsertVulnerabilities(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "InsertVulnerabilities", true)
|
||||
store, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilities")
|
||||
defer cleanup()
|
||||
|
||||
ns1 := database.Namespace{
|
||||
Name: "name",
|
||||
@ -56,45 +67,48 @@ func TestInsertVulnerabilities(t *testing.T) {
|
||||
Vulnerability: v2,
|
||||
}
|
||||
|
||||
tx, err := store.Begin()
|
||||
require.Nil(t, err)
|
||||
|
||||
// empty
|
||||
err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{})
|
||||
err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// invalid content: vwa1 is invalid
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2})
|
||||
err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa1, vwa2})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
tx = testutil.RestartTransaction(store, tx, false)
|
||||
// invalid content: duplicated input
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2})
|
||||
err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2, vwa2})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
tx = testutil.RestartTransaction(store, tx, false)
|
||||
// valid content
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
|
||||
err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2})
|
||||
assert.Nil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, true)
|
||||
tx = testutil.RestartTransaction(store, tx, true)
|
||||
// ensure the content is in database
|
||||
vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}})
|
||||
vulns, err := FindVulnerabilities(tx, []database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) && assert.Len(t, vulns, 1) {
|
||||
assert.True(t, vulns[0].Valid)
|
||||
}
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
tx = testutil.RestartTransaction(store, tx, false)
|
||||
// valid content: vwa2 removed and inserted
|
||||
err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}})
|
||||
err = DeleteVulnerabilities(tx, []database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
|
||||
err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2})
|
||||
assert.Nil(t, err)
|
||||
|
||||
closeTest(t, store, tx)
|
||||
require.Nil(t, tx.Rollback())
|
||||
}
|
||||
|
||||
func TestCachingVulnerable(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "CachingVulnerable", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "CachingVulnerable")
|
||||
defer cleanup()
|
||||
|
||||
ns := database.Namespace{
|
||||
Name: "debian:8",
|
||||
@ -163,11 +177,8 @@ func TestCachingVulnerable(t *testing.T) {
|
||||
FixedInVersion: "2.2",
|
||||
}
|
||||
|
||||
if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f})
|
||||
require.Nil(t, InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vuln, vuln2}))
|
||||
r, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{f})
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, r, 1)
|
||||
for _, anf := range r {
|
||||
@ -186,10 +197,10 @@ func TestCachingVulnerable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFindVulnerabilities(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindVulnerabilities", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilities")
|
||||
defer cleanup()
|
||||
|
||||
vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{
|
||||
vuln, err := FindVulnerabilities(tx, []database.VulnerabilityID{
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
{Name: "CVE-NOPE", Namespace: "debian:7"},
|
||||
{Name: "CVE-NOT HERE"},
|
||||
@ -255,7 +266,7 @@ func TestFindVulnerabilities(t *testing.T) {
|
||||
|
||||
expected, ok := expectedExistingMap[key]
|
||||
if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) {
|
||||
assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
|
||||
testutil.AssertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
|
||||
}
|
||||
} else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) {
|
||||
t.FailNow()
|
||||
@ -264,7 +275,7 @@ func TestFindVulnerabilities(t *testing.T) {
|
||||
}
|
||||
|
||||
// same vulnerability
|
||||
r, err := tx.FindVulnerabilities([]database.VulnerabilityID{
|
||||
r, err := FindVulnerabilities(tx, []database.VulnerabilityID{
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
})
|
||||
@ -273,22 +284,22 @@ func TestFindVulnerabilities(t *testing.T) {
|
||||
for _, vuln := range r {
|
||||
if assert.True(t, vuln.Valid) {
|
||||
expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}]
|
||||
assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected)
|
||||
testutil.AssertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVulnerabilities(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteVulnerabilities")
|
||||
defer cleanup()
|
||||
|
||||
remove := []database.VulnerabilityID{}
|
||||
// empty case
|
||||
assert.Nil(t, tx.DeleteVulnerabilities(remove))
|
||||
assert.Nil(t, DeleteVulnerabilities(tx, remove))
|
||||
// invalid case
|
||||
remove = append(remove, database.VulnerabilityID{})
|
||||
assert.NotNil(t, tx.DeleteVulnerabilities(remove))
|
||||
assert.NotNil(t, DeleteVulnerabilities(tx, remove))
|
||||
|
||||
// valid case
|
||||
validRemove := []database.VulnerabilityID{
|
||||
@ -296,8 +307,8 @@ func TestDeleteVulnerabilities(t *testing.T) {
|
||||
{Name: "CVE-NOPE", Namespace: "debian:7"},
|
||||
}
|
||||
|
||||
assert.Nil(t, tx.DeleteVulnerabilities(validRemove))
|
||||
vuln, err := tx.FindVulnerabilities(validRemove)
|
||||
assert.Nil(t, DeleteVulnerabilities(tx, validRemove))
|
||||
vuln, err := FindVulnerabilities(tx, validRemove)
|
||||
if assert.Nil(t, err) {
|
||||
for _, v := range vuln {
|
||||
assert.False(t, v.Valid)
|
||||
@ -306,20 +317,158 @@ func TestDeleteVulnerabilities(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFindVulnerabilityIDs(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true)
|
||||
defer closeTest(t, store, tx)
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilityIDs")
|
||||
defer cleanup()
|
||||
|
||||
ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
|
||||
ids, err := FindLatestDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) {
|
||||
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) {
|
||||
assert.Fail(t, "")
|
||||
}
|
||||
}
|
||||
|
||||
ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
|
||||
ids, err = FindNotDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) {
|
||||
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) {
|
||||
assert.Fail(t, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAffectedNamespacedFeatures(t *testing.T) {
|
||||
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindAffectedNamespacedFeatures")
|
||||
defer cleanup()
|
||||
|
||||
ns := database.NamespacedFeature{
|
||||
Feature: database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.0",
|
||||
VersionFormat: "dpkg",
|
||||
Type: database.SourcePackage,
|
||||
},
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
}
|
||||
|
||||
ans, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{ns})
|
||||
if assert.Nil(t, err) &&
|
||||
assert.Len(t, ans, 1) &&
|
||||
assert.True(t, ans[0].Valid) &&
|
||||
assert.Len(t, ans[0].AffectedBy, 1) {
|
||||
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomVulnerabilityAndNamespacedFeature(t *testing.T, store *sql.DB) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
|
||||
tx, err := store.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
numFeatures := 100
|
||||
numVulnerabilities := 100
|
||||
|
||||
featureName := "TestFeature"
|
||||
featureVersionFormat := dpkg.ParserName
|
||||
// Insert the namespace on which we'll work.
|
||||
ns := database.Namespace{
|
||||
Name: "TestRaceAffectsFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
|
||||
if !assert.Nil(t, namespace.PersistNamespaces(tx, []database.Namespace{ns})) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Generate Distinct random features
|
||||
features := make([]database.Feature, numFeatures)
|
||||
nsFeatures := make([]database.NamespacedFeature, numFeatures)
|
||||
for i := 0; i < numFeatures; i++ {
|
||||
version := rand.Intn(numFeatures)
|
||||
|
||||
features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
|
||||
nsFeatures[i] = database.NamespacedFeature{
|
||||
Namespace: ns,
|
||||
Feature: features[i],
|
||||
}
|
||||
}
|
||||
|
||||
if !assert.Nil(t, feature.PersistFeatures(tx, features)) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Generate vulnerabilities.
|
||||
vulnerabilities := []database.VulnerabilityWithAffected{}
|
||||
for i := 0; i < numVulnerabilities; i++ {
|
||||
// any version less than this is vulnerable
|
||||
version := rand.Intn(numFeatures) + 1
|
||||
|
||||
vulnerability := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: uuid.New(),
|
||||
Namespace: ns,
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
Namespace: ns,
|
||||
FeatureName: featureName,
|
||||
FeatureType: database.SourcePackage,
|
||||
AffectedVersion: strconv.Itoa(version),
|
||||
FixedInVersion: strconv.Itoa(version),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
vulnerabilities = append(vulnerabilities, vulnerability)
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return nsFeatures, vulnerabilities
|
||||
}
|
||||
|
||||
func TestVulnChangeAffectsVulnerableFeatures(t *testing.T) {
|
||||
db, cleanup := testutil.CreateTestDB(t, "caching")
|
||||
defer cleanup()
|
||||
|
||||
nsFeatures, vulnerabilities := genRandomVulnerabilityAndNamespacedFeature(t, db)
|
||||
tx, err := db.Begin()
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, feature.PersistNamespacedFeatures(tx, nsFeatures))
|
||||
require.Nil(t, tx.Commit())
|
||||
|
||||
tx, err = db.Begin()
|
||||
require.Nil(t, InsertVulnerabilities(tx, vulnerabilities))
|
||||
require.Nil(t, tx.Commit())
|
||||
|
||||
tx, err = db.Begin()
|
||||
require.Nil(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
affected, err := FindAffectedNamespacedFeatures(tx, nsFeatures)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, ansf := range affected {
|
||||
require.True(t, ansf.Valid)
|
||||
|
||||
expectedAffectedNames := []string{}
|
||||
for _, vuln := range vulnerabilities {
|
||||
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
|
||||
if ok {
|
||||
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actualAffectedNames := []string{}
|
||||
for _, s := range ansf.AffectedBy {
|
||||
actualAffectedNames = append(actualAffectedNames, s.Name)
|
||||
}
|
||||
|
||||
require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
|
||||
require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
|
||||
}
|
||||
}
|
50
database/vulnerability.go
Normal file
50
database/vulnerability.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package database
|
||||
|
||||
// VulnerabilityID is an identifier for every vulnerability. Every vulnerability
|
||||
// has unique namespace and name.
|
||||
type VulnerabilityID struct {
|
||||
Name string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
// Vulnerability represents CVE or similar vulnerability reports.
|
||||
type Vulnerability struct {
|
||||
Name string
|
||||
Namespace Namespace
|
||||
|
||||
Description string
|
||||
Link string
|
||||
Severity Severity
|
||||
|
||||
Metadata MetadataMap
|
||||
}
|
||||
|
||||
// VulnerabilityWithAffected is a vulnerability with all known affected
|
||||
// features.
|
||||
type VulnerabilityWithAffected struct {
|
||||
Vulnerability
|
||||
|
||||
Affected []AffectedFeature
|
||||
}
|
||||
|
||||
// NullableVulnerability is a vulnerability with whether the vulnerability is
|
||||
// found in datastore.
|
||||
type NullableVulnerability struct {
|
||||
VulnerabilityWithAffected
|
||||
|
||||
Valid bool
|
||||
}
|
Loading…
Reference in New Issue
Block a user