diff --git a/api/api.go b/api/api.go index cf1a6a6d..ff73c13e 100644 --- a/api/api.go +++ b/api/api.go @@ -26,7 +26,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/tylerb/graceful" - "github.com/coreos/clair/api/v2" + "github.com/coreos/clair/api/v3" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/stopper" ) @@ -35,11 +35,9 @@ const timeoutResponse = `{"Error":{"Message":"Clair failed to respond within the // Config is the configuration for the API service. type Config struct { - Port int GrpcPort int HealthPort int Timeout time.Duration - PaginationKey string CertFile, KeyFile, CAFile string } @@ -51,40 +49,7 @@ func RunV2(cfg *Config, store database.Datastore) { if tlsConfig != nil { log.Info("main API configured with client certificate authentication") } - v2.Run(cfg.GrpcPort, tlsConfig, cfg.PaginationKey, cfg.CertFile, cfg.KeyFile, store) -} - -func Run(cfg *Config, store database.Datastore, st *stopper.Stopper) { - defer st.End() - - // Do not run the API service if there is no config. - if cfg == nil { - log.Info("main API service is disabled.") - return - } - log.WithField("port", cfg.Port).Info("starting main API") - - tlsConfig, err := tlsClientConfig(cfg.CAFile) - if err != nil { - log.WithError(err).Fatal("could not initialize client cert authentication") - } - if tlsConfig != nil { - log.Info("main API configured with client certificate authentication") - } - - srv := &graceful.Server{ - Timeout: 0, // Already handled by our TimeOut middleware - NoSignalHandling: true, // We want to use our own Stopper - Server: &http.Server{ - Addr: ":" + strconv.Itoa(cfg.Port), - TLSConfig: tlsConfig, - Handler: http.TimeoutHandler(newAPIHandler(cfg, store), cfg.Timeout, timeoutResponse), - }, - } - - listenAndServeWithStopper(srv, st, cfg.CertFile, cfg.KeyFile) - - log.Info("main API stopped") + v3.Run(cfg.GrpcPort, tlsConfig, cfg.CertFile, cfg.KeyFile, store) } func RunHealth(cfg *Config, store database.Datastore, st *stopper.Stopper) { diff --git a/api/router.go b/api/router.go index 59ebf96b..c3bd3b41 100644 --- a/api/router.go +++ b/api/router.go @@ -16,13 +16,9 @@ package api import ( "net/http" - "strings" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" - "github.com/coreos/clair/api/httputil" - "github.com/coreos/clair/api/v1" "github.com/coreos/clair/database" ) @@ -30,34 +26,6 @@ import ( // depending on the API version specified in the request URI. type router map[string]*httprouter.Router -// Let's hope we never have more than 99 API versions. -const apiVersionLength = len("v99") - -func newAPIHandler(cfg *Config, store database.Datastore) http.Handler { - router := make(router) - router["/v1"] = v1.NewRouter(store, cfg.PaginationKey) - return router -} - -func (rtr router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - urlStr := r.URL.String() - var version string - if len(urlStr) >= apiVersionLength { - version = urlStr[:apiVersionLength] - } - - if router, _ := rtr[version]; router != nil { - // Remove the version number from the request path to let the router do its - // job but do not update the RequestURI - r.URL.Path = strings.Replace(r.URL.Path, version, "", 1) - router.ServeHTTP(w, r) - return - } - - log.WithFields(log.Fields{"status": http.StatusNotFound, "method": r.Method, "request uri": r.RequestURI, "remote addr": httputil.GetClientAddr(r)}).Info("Served HTTP request") - http.NotFound(w, r) -} - func newHealthHandler(store database.Datastore) http.Handler { router := httprouter.New() router.GET("/health", healthHandler(store)) diff --git a/api/v1/models.go b/api/v1/models.go deleted file mode 100644 index 2a1b8065..00000000 --- a/api/v1/models.go +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "fmt" - - "github.com/coreos/clair/api/token" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" -) - -type Error struct { - Message string `json:"Message,omitempty"` -} - -type Layer struct { - Name string `json:"Name,omitempty"` - NamespaceNames []string `json:"NamespaceNames,omitempty"` - Path string `json:"Path,omitempty"` - Headers map[string]string `json:"Headers,omitempty"` - ParentName string `json:"ParentName,omitempty"` - Format string `json:"Format,omitempty"` - IndexedByVersion int `json:"IndexedByVersion,omitempty"` - Features []Feature `json:"Features,omitempty"` -} - -func LayerFromDatabaseModel(dbLayer database.Layer, withFeatures, withVulnerabilities bool) Layer { - layer := Layer{ - Name: dbLayer.Name, - IndexedByVersion: dbLayer.EngineVersion, - } - - if dbLayer.Parent != nil { - layer.ParentName = dbLayer.Parent.Name - } - - for _, ns := range dbLayer.Namespaces { - layer.NamespaceNames = append(layer.NamespaceNames, ns.Name) - } - - if withFeatures || withVulnerabilities && dbLayer.Features != nil { - for _, dbFeatureVersion := range dbLayer.Features { - feature := Feature{ - Name: dbFeatureVersion.Feature.Name, - NamespaceName: dbFeatureVersion.Feature.Namespace.Name, - VersionFormat: dbFeatureVersion.Feature.Namespace.VersionFormat, - Version: dbFeatureVersion.Version, - AddedBy: dbFeatureVersion.AddedBy.Name, - } - - for _, dbVuln := range dbFeatureVersion.AffectedBy { - vuln := Vulnerability{ - Name: dbVuln.Name, - NamespaceName: dbVuln.Namespace.Name, - Description: dbVuln.Description, - Link: dbVuln.Link, - Severity: string(dbVuln.Severity), - Metadata: dbVuln.Metadata, - } - - if dbVuln.FixedBy != versionfmt.MaxVersion { - vuln.FixedBy = dbVuln.FixedBy - } - feature.Vulnerabilities = append(feature.Vulnerabilities, vuln) - } - layer.Features = append(layer.Features, feature) - } - } - - return layer -} - -type Namespace struct { - Name string `json:"Name,omitempty"` - VersionFormat string `json:"VersionFormat,omitempty"` -} - -type Vulnerability struct { - Name string `json:"Name,omitempty"` - NamespaceName string `json:"NamespaceName,omitempty"` - Description string `json:"Description,omitempty"` - Link string `json:"Link,omitempty"` - Severity string `json:"Severity,omitempty"` - Metadata map[string]interface{} `json:"Metadata,omitempty"` - FixedBy string `json:"FixedBy,omitempty"` - FixedIn []Feature `json:"FixedIn,omitempty"` -} - -func (v Vulnerability) DatabaseModel() (database.Vulnerability, error) { - severity, err := database.NewSeverity(v.Severity) - if err != nil { - return database.Vulnerability{}, err - } - - var dbFeatures []database.FeatureVersion - for _, feature := range v.FixedIn { - dbFeature, err := feature.DatabaseModel() - if err != nil { - return database.Vulnerability{}, err - } - - dbFeatures = append(dbFeatures, dbFeature) - } - - return database.Vulnerability{ - Name: v.Name, - Namespace: database.Namespace{Name: v.NamespaceName}, - Description: v.Description, - Link: v.Link, - Severity: severity, - Metadata: v.Metadata, - FixedIn: dbFeatures, - }, nil -} - -func VulnerabilityFromDatabaseModel(dbVuln database.Vulnerability, withFixedIn bool) Vulnerability { - vuln := Vulnerability{ - Name: dbVuln.Name, - NamespaceName: dbVuln.Namespace.Name, - Description: dbVuln.Description, - Link: dbVuln.Link, - Severity: string(dbVuln.Severity), - Metadata: dbVuln.Metadata, - } - - if withFixedIn { - for _, dbFeatureVersion := range dbVuln.FixedIn { - vuln.FixedIn = append(vuln.FixedIn, FeatureFromDatabaseModel(dbFeatureVersion)) - } - } - - return vuln -} - -type Feature struct { - Name string `json:"Name,omitempty"` - NamespaceName string `json:"NamespaceName,omitempty"` - VersionFormat string `json:"VersionFormat,omitempty"` - Version string `json:"Version,omitempty"` - Vulnerabilities []Vulnerability `json:"Vulnerabilities,omitempty"` - AddedBy string `json:"AddedBy,omitempty"` -} - -func FeatureFromDatabaseModel(dbFeatureVersion database.FeatureVersion) Feature { - version := dbFeatureVersion.Version - if version == versionfmt.MaxVersion { - version = "None" - } - - return Feature{ - Name: dbFeatureVersion.Feature.Name, - NamespaceName: dbFeatureVersion.Feature.Namespace.Name, - VersionFormat: dbFeatureVersion.Feature.Namespace.VersionFormat, - Version: version, - AddedBy: dbFeatureVersion.AddedBy.Name, - } -} - -func (f Feature) DatabaseModel() (fv database.FeatureVersion, err error) { - var version string - if f.Version == "None" { - version = versionfmt.MaxVersion - } else { - err = versionfmt.Valid(f.VersionFormat, f.Version) - if err != nil { - return - } - version = f.Version - } - - fv = database.FeatureVersion{ - Feature: database.Feature{ - Name: f.Name, - Namespace: database.Namespace{ - Name: f.NamespaceName, - VersionFormat: f.VersionFormat, - }, - }, - Version: version, - } - - return -} - -type Notification struct { - Name string `json:"Name,omitempty"` - Created string `json:"Created,omitempty"` - Notified string `json:"Notified,omitempty"` - Deleted string `json:"Deleted,omitempty"` - Limit int `json:"Limit,omitempty"` - Page string `json:"Page,omitempty"` - NextPage string `json:"NextPage,omitempty"` - Old *VulnerabilityWithLayers `json:"Old,omitempty"` - New *VulnerabilityWithLayers `json:"New,omitempty"` -} - -func NotificationFromDatabaseModel(dbNotification database.VulnerabilityNotification, limit int, pageToken string, nextPage database.VulnerabilityNotificationPageNumber, key string) Notification { - var oldVuln *VulnerabilityWithLayers - if dbNotification.OldVulnerability != nil { - v := VulnerabilityWithLayersFromDatabaseModel(*dbNotification.OldVulnerability) - oldVuln = &v - } - - var newVuln *VulnerabilityWithLayers - if dbNotification.NewVulnerability != nil { - v := VulnerabilityWithLayersFromDatabaseModel(*dbNotification.NewVulnerability) - newVuln = &v - } - - var nextPageStr string - if nextPage != database.NoVulnerabilityNotificationPage { - nextPageBytes, _ := token.Marshal(nextPage, key) - nextPageStr = string(nextPageBytes) - } - - var created, notified, deleted string - if !dbNotification.Created.IsZero() { - created = fmt.Sprintf("%d", dbNotification.Created.Unix()) - } - if !dbNotification.Notified.IsZero() { - notified = fmt.Sprintf("%d", dbNotification.Notified.Unix()) - } - if !dbNotification.Deleted.IsZero() { - deleted = fmt.Sprintf("%d", dbNotification.Deleted.Unix()) - } - - // TODO(jzelinskie): implement "changed" key - fmt.Println(dbNotification.Deleted.IsZero()) - return Notification{ - Name: dbNotification.Name, - Created: created, - Notified: notified, - Deleted: deleted, - Limit: limit, - Page: pageToken, - NextPage: nextPageStr, - Old: oldVuln, - New: newVuln, - } -} - -type VulnerabilityWithLayers struct { - Vulnerability *Vulnerability `json:"Vulnerability,omitempty"` - - // This field is guaranteed to be in order only for pagination. - // Indices from different notifications may not be comparable. - OrderedLayersIntroducingVulnerability []OrderedLayerName `json:"OrderedLayersIntroducingVulnerability,omitempty"` - - // This field is deprecated. - LayersIntroducingVulnerability []string `json:"LayersIntroducingVulnerability,omitempty"` -} - -type OrderedLayerName struct { - Index int `json:"Index"` - LayerName string `json:"LayerName"` -} - -func VulnerabilityWithLayersFromDatabaseModel(dbVuln database.Vulnerability) VulnerabilityWithLayers { - vuln := VulnerabilityFromDatabaseModel(dbVuln, true) - - var layers []string - var orderedLayers []OrderedLayerName - for _, layer := range dbVuln.LayersIntroducingVulnerability { - layers = append(layers, layer.Name) - orderedLayers = append(orderedLayers, OrderedLayerName{ - Index: layer.ID, - LayerName: layer.Name, - }) - } - - return VulnerabilityWithLayers{ - Vulnerability: &vuln, - OrderedLayersIntroducingVulnerability: orderedLayers, - LayersIntroducingVulnerability: layers, - } -} - -type LayerEnvelope struct { - Layer *Layer `json:"Layer,omitempty"` - Error *Error `json:"Error,omitempty"` -} - -type NamespaceEnvelope struct { - Namespaces *[]Namespace `json:"Namespaces,omitempty"` - Error *Error `json:"Error,omitempty"` -} - -type VulnerabilityEnvelope struct { - Vulnerability *Vulnerability `json:"Vulnerability,omitempty"` - Vulnerabilities *[]Vulnerability `json:"Vulnerabilities,omitempty"` - NextPage string `json:"NextPage,omitempty"` - Error *Error `json:"Error,omitempty"` -} - -type NotificationEnvelope struct { - Notification *Notification `json:"Notification,omitempty"` - Error *Error `json:"Error,omitempty"` -} - -type FeatureEnvelope struct { - Feature *Feature `json:"Feature,omitempty"` - Features *[]Feature `json:"Features,omitempty"` - Error *Error `json:"Error,omitempty"` -} diff --git a/api/v1/router.go b/api/v1/router.go deleted file mode 100644 index d5e93eeb..00000000 --- a/api/v1/router.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2015 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package v1 implements the first version of the Clair API. -package v1 - -import ( - "net/http" - "strconv" - "time" - - "github.com/julienschmidt/httprouter" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair/api/httputil" - "github.com/coreos/clair/database" -) - -var ( - promResponseDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "clair_api_response_duration_milliseconds", - Help: "The duration of time it takes to receieve and write a response to an API request", - Buckets: prometheus.ExponentialBuckets(9.375, 2, 10), - }, []string{"route", "code"}) -) - -func init() { - prometheus.MustRegister(promResponseDurationMilliseconds) -} - -type handler func(http.ResponseWriter, *http.Request, httprouter.Params, *context) (route string, status int) - -func httpHandler(h handler, ctx *context) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - start := time.Now() - route, status := h(w, r, p, ctx) - statusStr := strconv.Itoa(status) - if status == 0 { - statusStr = "???" - } - - promResponseDurationMilliseconds. - WithLabelValues(route, statusStr). - Observe(float64(time.Since(start).Nanoseconds()) / float64(time.Millisecond)) - - log.WithFields(log.Fields{"remote addr": httputil.GetClientAddr(r), "method": r.Method, "request uri": r.RequestURI, "status": statusStr, "elapsed time": time.Since(start)}).Info("Handled HTTP request") - } -} - -type context struct { - Store database.Datastore - PaginationKey string -} - -// NewRouter creates an HTTP router for version 1 of the Clair API. -func NewRouter(store database.Datastore, paginationKey string) *httprouter.Router { - router := httprouter.New() - ctx := &context{store, paginationKey} - - // Layers - router.POST("/layers", httpHandler(postLayer, ctx)) - router.GET("/layers/:layerName", httpHandler(getLayer, ctx)) - router.DELETE("/layers/:layerName", httpHandler(deleteLayer, ctx)) - - // Namespaces - router.GET("/namespaces", httpHandler(getNamespaces, ctx)) - - // Vulnerabilities - router.GET("/namespaces/:namespaceName/vulnerabilities", httpHandler(getVulnerabilities, ctx)) - router.POST("/namespaces/:namespaceName/vulnerabilities", httpHandler(postVulnerability, ctx)) - router.GET("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName", httpHandler(getVulnerability, ctx)) - router.PUT("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName", httpHandler(putVulnerability, ctx)) - router.DELETE("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName", httpHandler(deleteVulnerability, ctx)) - - // Fixes - router.GET("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName/fixes", httpHandler(getFixes, ctx)) - router.PUT("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName/fixes/:fixName", httpHandler(putFix, ctx)) - router.DELETE("/namespaces/:namespaceName/vulnerabilities/:vulnerabilityName/fixes/:fixName", httpHandler(deleteFix, ctx)) - - // Notifications - router.GET("/notifications/:notificationName", httpHandler(getNotification, ctx)) - router.DELETE("/notifications/:notificationName", httpHandler(deleteNotification, ctx)) - - // Metrics - router.GET("/metrics", httpHandler(getMetrics, ctx)) - - return router -} diff --git a/api/v1/routes.go b/api/v1/routes.go deleted file mode 100644 index 9a5f6bb3..00000000 --- a/api/v1/routes.go +++ /dev/null @@ -1,503 +0,0 @@ -// Copyright 2015 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "compress/gzip" - "encoding/json" - "io" - "net/http" - "strconv" - "strings" - - "github.com/julienschmidt/httprouter" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - - "github.com/coreos/clair" - "github.com/coreos/clair/api/token" - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/commonerr" - "github.com/coreos/clair/pkg/tarutil" -) - -const ( - // These are the route identifiers for prometheus. - postLayerRoute = "v1/postLayer" - getLayerRoute = "v1/getLayer" - deleteLayerRoute = "v1/deleteLayer" - getNamespacesRoute = "v1/getNamespaces" - getVulnerabilitiesRoute = "v1/getVulnerabilities" - postVulnerabilityRoute = "v1/postVulnerability" - getVulnerabilityRoute = "v1/getVulnerability" - putVulnerabilityRoute = "v1/putVulnerability" - deleteVulnerabilityRoute = "v1/deleteVulnerability" - getFixesRoute = "v1/getFixes" - putFixRoute = "v1/putFix" - deleteFixRoute = "v1/deleteFix" - getNotificationRoute = "v1/getNotification" - deleteNotificationRoute = "v1/deleteNotification" - getMetricsRoute = "v1/getMetrics" - - // maxBodySize restricts client request bodies to 1MiB. - maxBodySize int64 = 1048576 - - // statusUnprocessableEntity represents the 422 (Unprocessable Entity) status code, which means - // the server understands the content type of the request entity - // (hence a 415(Unsupported Media Type) status code is inappropriate), and the syntax of the - // request entity is correct (thus a 400 (Bad Request) status code is inappropriate) but was - // unable to process the contained instructions. - statusUnprocessableEntity = 422 -) - -func decodeJSON(r *http.Request, v interface{}) error { - defer r.Body.Close() - return json.NewDecoder(io.LimitReader(r.Body, maxBodySize)).Decode(v) -} - -func writeResponse(w http.ResponseWriter, r *http.Request, status int, resp interface{}) { - // Headers must be written before the response. - header := w.Header() - header.Set("Content-Type", "application/json;charset=utf-8") - header.Set("Server", "clair") - - // Gzip the response if the client supports it. - var writer io.Writer = w - if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - gzipWriter := gzip.NewWriter(w) - defer gzipWriter.Close() - writer = gzipWriter - - header.Set("Content-Encoding", "gzip") - } - - // Write the response. - w.WriteHeader(status) - err := json.NewEncoder(writer).Encode(resp) - - if err != nil { - switch err.(type) { - case *json.MarshalerError, *json.UnsupportedTypeError, *json.UnsupportedValueError: - panic("v1: failed to marshal response: " + err.Error()) - default: - log.WithError(err).Warning("failed to write response") - } - } -} - -func postLayer(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - request := LayerEnvelope{} - err := decodeJSON(r, &request) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, LayerEnvelope{Error: &Error{err.Error()}}) - return postLayerRoute, http.StatusBadRequest - } - - if request.Layer == nil { - writeResponse(w, r, http.StatusBadRequest, LayerEnvelope{Error: &Error{"failed to provide layer"}}) - return postLayerRoute, http.StatusBadRequest - } - - err = clair.ProcessLayer(ctx.Store, request.Layer.Format, request.Layer.Name, request.Layer.ParentName, request.Layer.Path, request.Layer.Headers) - if err != nil { - if err == tarutil.ErrCouldNotExtract || - err == tarutil.ErrExtractedFileTooBig || - err == clair.ErrUnsupported { - writeResponse(w, r, statusUnprocessableEntity, LayerEnvelope{Error: &Error{err.Error()}}) - return postLayerRoute, statusUnprocessableEntity - } - - if _, badreq := err.(*commonerr.ErrBadRequest); badreq { - writeResponse(w, r, http.StatusBadRequest, LayerEnvelope{Error: &Error{err.Error()}}) - return postLayerRoute, http.StatusBadRequest - } - - writeResponse(w, r, http.StatusInternalServerError, LayerEnvelope{Error: &Error{err.Error()}}) - return postLayerRoute, http.StatusInternalServerError - } - - writeResponse(w, r, http.StatusCreated, LayerEnvelope{Layer: &Layer{ - Name: request.Layer.Name, - ParentName: request.Layer.ParentName, - Path: request.Layer.Path, - Headers: request.Layer.Headers, - Format: request.Layer.Format, - IndexedByVersion: clair.Version, - }}) - return postLayerRoute, http.StatusCreated -} - -func getLayer(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - _, withFeatures := r.URL.Query()["features"] - _, withVulnerabilities := r.URL.Query()["vulnerabilities"] - - dbLayer, err := ctx.Store.FindLayer(p.ByName("layerName"), withFeatures, withVulnerabilities) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, LayerEnvelope{Error: &Error{err.Error()}}) - return getLayerRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, LayerEnvelope{Error: &Error{err.Error()}}) - return getLayerRoute, http.StatusInternalServerError - } - - layer := LayerFromDatabaseModel(dbLayer, withFeatures, withVulnerabilities) - - writeResponse(w, r, http.StatusOK, LayerEnvelope{Layer: &layer}) - return getLayerRoute, http.StatusOK -} - -func deleteLayer(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - err := ctx.Store.DeleteLayer(p.ByName("layerName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, LayerEnvelope{Error: &Error{err.Error()}}) - return deleteLayerRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, LayerEnvelope{Error: &Error{err.Error()}}) - return deleteLayerRoute, http.StatusInternalServerError - } - - w.WriteHeader(http.StatusOK) - return deleteLayerRoute, http.StatusOK -} - -func getNamespaces(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - dbNamespaces, err := ctx.Store.ListNamespaces() - if err != nil { - writeResponse(w, r, http.StatusInternalServerError, NamespaceEnvelope{Error: &Error{err.Error()}}) - return getNamespacesRoute, http.StatusInternalServerError - } - var namespaces []Namespace - for _, dbNamespace := range dbNamespaces { - namespaces = append(namespaces, Namespace{ - Name: dbNamespace.Name, - VersionFormat: dbNamespace.VersionFormat, - }) - } - - writeResponse(w, r, http.StatusOK, NamespaceEnvelope{Namespaces: &namespaces}) - return getNamespacesRoute, http.StatusOK -} - -func getVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - query := r.URL.Query() - - limitStrs, limitExists := query["limit"] - if !limitExists { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"must provide limit query parameter"}}) - return getVulnerabilitiesRoute, http.StatusBadRequest - } - limit, err := strconv.Atoi(limitStrs[0]) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"invalid limit format: " + err.Error()}}) - return getVulnerabilitiesRoute, http.StatusBadRequest - } else if limit < 0 { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"limit value should not be less than zero"}}) - return getVulnerabilitiesRoute, http.StatusBadRequest - } - - page := 0 - pageStrs, pageExists := query["page"] - if pageExists { - err = token.Unmarshal(pageStrs[0], ctx.PaginationKey, &page) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"invalid page format: " + err.Error()}}) - return getNotificationRoute, http.StatusBadRequest - } - } - - namespace := p.ByName("namespaceName") - if namespace == "" { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"namespace should not be empty"}}) - return getNotificationRoute, http.StatusBadRequest - } - - dbVulns, nextPage, err := ctx.Store.ListVulnerabilities(namespace, limit, page) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return getVulnerabilityRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return getVulnerabilitiesRoute, http.StatusInternalServerError - } - - var vulns []Vulnerability - for _, dbVuln := range dbVulns { - vuln := VulnerabilityFromDatabaseModel(dbVuln, false) - vulns = append(vulns, vuln) - } - - var nextPageStr string - if nextPage != -1 { - nextPageBytes, err := token.Marshal(nextPage, ctx.PaginationKey) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"failed to marshal token: " + err.Error()}}) - return getNotificationRoute, http.StatusBadRequest - } - nextPageStr = string(nextPageBytes) - } - - writeResponse(w, r, http.StatusOK, VulnerabilityEnvelope{Vulnerabilities: &vulns, NextPage: nextPageStr}) - return getVulnerabilitiesRoute, http.StatusOK -} - -func postVulnerability(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - request := VulnerabilityEnvelope{} - err := decodeJSON(r, &request) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return postVulnerabilityRoute, http.StatusBadRequest - } - - if request.Vulnerability == nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"failed to provide vulnerability"}}) - return postVulnerabilityRoute, http.StatusBadRequest - } - - vuln, err := request.Vulnerability.DatabaseModel() - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return postVulnerabilityRoute, http.StatusBadRequest - } - - err = ctx.Store.InsertVulnerabilities([]database.Vulnerability{vuln}, true) - if err != nil { - switch err.(type) { - case *commonerr.ErrBadRequest: - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return postVulnerabilityRoute, http.StatusBadRequest - default: - writeResponse(w, r, http.StatusInternalServerError, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return postVulnerabilityRoute, http.StatusInternalServerError - } - } - - writeResponse(w, r, http.StatusCreated, VulnerabilityEnvelope{Vulnerability: request.Vulnerability}) - return postVulnerabilityRoute, http.StatusCreated -} - -func getVulnerability(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - _, withFixedIn := r.URL.Query()["fixedIn"] - - dbVuln, err := ctx.Store.FindVulnerability(p.ByName("namespaceName"), p.ByName("vulnerabilityName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return getVulnerabilityRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return getVulnerabilityRoute, http.StatusInternalServerError - } - - vuln := VulnerabilityFromDatabaseModel(dbVuln, withFixedIn) - - writeResponse(w, r, http.StatusOK, VulnerabilityEnvelope{Vulnerability: &vuln}) - return getVulnerabilityRoute, http.StatusOK -} - -func putVulnerability(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - request := VulnerabilityEnvelope{} - err := decodeJSON(r, &request) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return putVulnerabilityRoute, http.StatusBadRequest - } - - if request.Vulnerability == nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"failed to provide vulnerability"}}) - return putVulnerabilityRoute, http.StatusBadRequest - } - - if len(request.Vulnerability.FixedIn) != 0 { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{"Vulnerability.FixedIn must be empty"}}) - return putVulnerabilityRoute, http.StatusBadRequest - } - - vuln, err := request.Vulnerability.DatabaseModel() - if err != nil { - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return putVulnerabilityRoute, http.StatusBadRequest - } - - vuln.Namespace.Name = p.ByName("namespaceName") - vuln.Name = p.ByName("vulnerabilityName") - - err = ctx.Store.InsertVulnerabilities([]database.Vulnerability{vuln}, true) - if err != nil { - switch err.(type) { - case *commonerr.ErrBadRequest: - writeResponse(w, r, http.StatusBadRequest, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return putVulnerabilityRoute, http.StatusBadRequest - default: - writeResponse(w, r, http.StatusInternalServerError, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return putVulnerabilityRoute, http.StatusInternalServerError - } - } - - writeResponse(w, r, http.StatusOK, VulnerabilityEnvelope{Vulnerability: request.Vulnerability}) - return putVulnerabilityRoute, http.StatusOK -} - -func deleteVulnerability(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - err := ctx.Store.DeleteVulnerability(p.ByName("namespaceName"), p.ByName("vulnerabilityName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return deleteVulnerabilityRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, VulnerabilityEnvelope{Error: &Error{err.Error()}}) - return deleteVulnerabilityRoute, http.StatusInternalServerError - } - - w.WriteHeader(http.StatusOK) - return deleteVulnerabilityRoute, http.StatusOK -} - -func getFixes(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - dbVuln, err := ctx.Store.FindVulnerability(p.ByName("namespaceName"), p.ByName("vulnerabilityName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, FeatureEnvelope{Error: &Error{err.Error()}}) - return getFixesRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, FeatureEnvelope{Error: &Error{err.Error()}}) - return getFixesRoute, http.StatusInternalServerError - } - - vuln := VulnerabilityFromDatabaseModel(dbVuln, true) - writeResponse(w, r, http.StatusOK, FeatureEnvelope{Features: &vuln.FixedIn}) - return getFixesRoute, http.StatusOK -} - -func putFix(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - request := FeatureEnvelope{} - err := decodeJSON(r, &request) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, FeatureEnvelope{Error: &Error{err.Error()}}) - return putFixRoute, http.StatusBadRequest - } - - if request.Feature == nil { - writeResponse(w, r, http.StatusBadRequest, FeatureEnvelope{Error: &Error{"failed to provide feature"}}) - return putFixRoute, http.StatusBadRequest - } - - if request.Feature.Name != p.ByName("fixName") { - writeResponse(w, r, http.StatusBadRequest, FeatureEnvelope{Error: &Error{"feature name in URL and JSON do not match"}}) - return putFixRoute, http.StatusBadRequest - } - - dbFix, err := request.Feature.DatabaseModel() - if err != nil { - writeResponse(w, r, http.StatusBadRequest, FeatureEnvelope{Error: &Error{err.Error()}}) - return putFixRoute, http.StatusBadRequest - } - - err = ctx.Store.InsertVulnerabilityFixes(p.ByName("vulnerabilityNamespace"), p.ByName("vulnerabilityName"), []database.FeatureVersion{dbFix}) - if err != nil { - switch err.(type) { - case *commonerr.ErrBadRequest: - writeResponse(w, r, http.StatusBadRequest, FeatureEnvelope{Error: &Error{err.Error()}}) - return putFixRoute, http.StatusBadRequest - default: - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, FeatureEnvelope{Error: &Error{err.Error()}}) - return putFixRoute, http.StatusNotFound - } - writeResponse(w, r, http.StatusInternalServerError, FeatureEnvelope{Error: &Error{err.Error()}}) - return putFixRoute, http.StatusInternalServerError - } - } - - writeResponse(w, r, http.StatusOK, FeatureEnvelope{Feature: request.Feature}) - return putFixRoute, http.StatusOK -} - -func deleteFix(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - err := ctx.Store.DeleteVulnerabilityFix(p.ByName("vulnerabilityNamespace"), p.ByName("vulnerabilityName"), p.ByName("fixName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, FeatureEnvelope{Error: &Error{err.Error()}}) - return deleteFixRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, FeatureEnvelope{Error: &Error{err.Error()}}) - return deleteFixRoute, http.StatusInternalServerError - } - - w.WriteHeader(http.StatusOK) - return deleteFixRoute, http.StatusOK -} - -func getNotification(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - query := r.URL.Query() - - limitStrs, limitExists := query["limit"] - if !limitExists { - writeResponse(w, r, http.StatusBadRequest, NotificationEnvelope{Error: &Error{"must provide limit query parameter"}}) - return getNotificationRoute, http.StatusBadRequest - } - limit, err := strconv.Atoi(limitStrs[0]) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, NotificationEnvelope{Error: &Error{"invalid limit format: " + err.Error()}}) - return getNotificationRoute, http.StatusBadRequest - } - - var pageToken string - page := database.VulnerabilityNotificationFirstPage - pageStrs, pageExists := query["page"] - if pageExists { - err := token.Unmarshal(pageStrs[0], ctx.PaginationKey, &page) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, NotificationEnvelope{Error: &Error{"invalid page format: " + err.Error()}}) - return getNotificationRoute, http.StatusBadRequest - } - pageToken = pageStrs[0] - } else { - pageTokenBytes, err := token.Marshal(page, ctx.PaginationKey) - if err != nil { - writeResponse(w, r, http.StatusBadRequest, NotificationEnvelope{Error: &Error{"failed to marshal token: " + err.Error()}}) - return getNotificationRoute, http.StatusBadRequest - } - pageToken = string(pageTokenBytes) - } - - dbNotification, nextPage, err := ctx.Store.GetNotification(p.ByName("notificationName"), limit, page) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, NotificationEnvelope{Error: &Error{err.Error()}}) - return deleteNotificationRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, NotificationEnvelope{Error: &Error{err.Error()}}) - return getNotificationRoute, http.StatusInternalServerError - } - - notification := NotificationFromDatabaseModel(dbNotification, limit, pageToken, nextPage, ctx.PaginationKey) - - writeResponse(w, r, http.StatusOK, NotificationEnvelope{Notification: ¬ification}) - return getNotificationRoute, http.StatusOK -} - -func deleteNotification(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - err := ctx.Store.DeleteNotification(p.ByName("notificationName")) - if err == commonerr.ErrNotFound { - writeResponse(w, r, http.StatusNotFound, NotificationEnvelope{Error: &Error{err.Error()}}) - return deleteNotificationRoute, http.StatusNotFound - } else if err != nil { - writeResponse(w, r, http.StatusInternalServerError, NotificationEnvelope{Error: &Error{err.Error()}}) - return deleteNotificationRoute, http.StatusInternalServerError - } - - w.WriteHeader(http.StatusOK) - return deleteNotificationRoute, http.StatusOK -} - -func getMetrics(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *context) (string, int) { - prometheus.Handler().ServeHTTP(w, r) - return getMetricsRoute, 0 -} diff --git a/api/v2/clairpb/convert.go b/api/v2/clairpb/convert.go deleted file mode 100644 index 34e7eb6e..00000000 --- a/api/v2/clairpb/convert.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package clairpb - -import ( - "encoding/json" - "fmt" - - "github.com/coreos/clair/api/token" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" -) - -func NotificationFromDatabaseModel(dbNotification database.VulnerabilityNotification, limit int, pageToken string, nextPage database.VulnerabilityNotificationPageNumber, key string) (*Notification, error) { - var oldVuln *LayersIntroducingVulnerabilty - if dbNotification.OldVulnerability != nil { - v, err := LayersIntroducingVulnerabiltyFromDatabaseModel(*dbNotification.OldVulnerability) - if err != nil { - return nil, err - } - oldVuln = v - } - - var newVuln *LayersIntroducingVulnerabilty - if dbNotification.NewVulnerability != nil { - v, err := LayersIntroducingVulnerabiltyFromDatabaseModel(*dbNotification.NewVulnerability) - if err != nil { - return nil, err - } - newVuln = v - } - - var nextPageStr string - if nextPage != database.NoVulnerabilityNotificationPage { - nextPageBytes, _ := token.Marshal(nextPage, key) - nextPageStr = string(nextPageBytes) - } - - var created, notified, deleted string - if !dbNotification.Created.IsZero() { - created = fmt.Sprintf("%d", dbNotification.Created.Unix()) - } - if !dbNotification.Notified.IsZero() { - notified = fmt.Sprintf("%d", dbNotification.Notified.Unix()) - } - if !dbNotification.Deleted.IsZero() { - deleted = fmt.Sprintf("%d", dbNotification.Deleted.Unix()) - } - - return &Notification{ - Name: dbNotification.Name, - Created: created, - Notified: notified, - Deleted: deleted, - Limit: int32(limit), - Page: &Page{ - ThisToken: pageToken, - NextToken: nextPageStr, - Old: oldVuln, - New: newVuln, - }, - }, nil -} - -func LayersIntroducingVulnerabiltyFromDatabaseModel(dbVuln database.Vulnerability) (*LayersIntroducingVulnerabilty, error) { - vuln, err := VulnerabilityFromDatabaseModel(dbVuln, true) - if err != nil { - return nil, err - } - var orderedLayers []*OrderedLayerName - - return &LayersIntroducingVulnerabilty{ - Vulnerability: vuln, - Layers: orderedLayers, - }, nil -} - -func VulnerabilityFromDatabaseModel(dbVuln database.Vulnerability, withFixedIn bool) (*Vulnerability, error) { - metaString := "" - if dbVuln.Metadata != nil { - metadataByte, err := json.Marshal(dbVuln.Metadata) - if err != nil { - return nil, err - } - metaString = string(metadataByte) - } - - vuln := Vulnerability{ - Name: dbVuln.Name, - NamespaceName: dbVuln.Namespace.Name, - Description: dbVuln.Description, - Link: dbVuln.Link, - Severity: string(dbVuln.Severity), - Metadata: metaString, - } - - if dbVuln.FixedBy != versionfmt.MaxVersion { - vuln.FixedBy = dbVuln.FixedBy - } - - if withFixedIn { - for _, dbFeatureVersion := range dbVuln.FixedIn { - f, err := FeatureFromDatabaseModel(dbFeatureVersion, false) - if err != nil { - return nil, err - } - - vuln.FixedInFeatures = append(vuln.FixedInFeatures, f) - } - } - - return &vuln, nil -} - -func LayerFromDatabaseModel(dbLayer database.Layer) *Layer { - layer := Layer{ - Name: dbLayer.Name, - } - for _, ns := range dbLayer.Namespaces { - layer.NamespaceNames = append(layer.NamespaceNames, ns.Name) - } - - return &layer -} - -func FeatureFromDatabaseModel(fv database.FeatureVersion, withVulnerabilities bool) (*Feature, error) { - version := fv.Version - if version == versionfmt.MaxVersion { - version = "None" - } - f := &Feature{ - Name: fv.Feature.Name, - NamespaceName: fv.Feature.Namespace.Name, - VersionFormat: fv.Feature.Namespace.VersionFormat, - Version: version, - AddedBy: fv.AddedBy.Name, - } - - if withVulnerabilities { - for _, dbVuln := range fv.AffectedBy { - // VulnerabilityFromDatabaseModel should be called without FixedIn, - // Otherwise it might cause infinite loop - vul, err := VulnerabilityFromDatabaseModel(dbVuln, false) - if err != nil { - return nil, err - } - - f.Vulnerabilities = append(f.Vulnerabilities, vul) - } - } - - return f, nil -} diff --git a/api/v2/rpc.go b/api/v2/rpc.go deleted file mode 100644 index 9ff840cb..00000000 --- a/api/v2/rpc.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v2 - -import ( - "fmt" - - google_protobuf1 "github.com/golang/protobuf/ptypes/empty" - log "github.com/sirupsen/logrus" - "golang.org/x/net/context" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/coreos/clair" - "github.com/coreos/clair/api/token" - pb "github.com/coreos/clair/api/v2/clairpb" - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/commonerr" - "github.com/coreos/clair/pkg/tarutil" -) - -// NotificationServer implements NotificationService interface for serving RPC. -type NotificationServer struct { - Store database.Datastore - PaginationKey string -} - -// AncestryServer implements AncestryService interface for serving RPC. -type AncestryServer struct { - Store database.Datastore -} - -// PostAncestry implements posting an ancestry via the Clair gRPC service. -func (s *AncestryServer) PostAncestry(ctx context.Context, req *pb.PostAncestryRequest) (*pb.PostAncestryResponse, error) { - ancestryName := req.GetAncestryName() - if ancestryName == "" { - return nil, status.Error(codes.InvalidArgument, "Failed to provide proper ancestry name") - } - - layers := req.GetLayers() - if len(layers) == 0 { - return nil, status.Error(codes.InvalidArgument, "At least one layer should be provided for an ancestry") - } - - var currentName, parentName, rootName string - for i, layer := range layers { - if layer == nil { - err := status.Error(codes.InvalidArgument, "Failed to provide layer") - return nil, s.rollBackOnError(err, currentName, rootName) - } - - // TODO(keyboardnerd): after altering the database to support ancestry, - // we should use the ancestry name and index as key instead of - // the amalgamation of ancestry name of index - // Hack: layer name is [ancestryName]-[index] except the tail layer, - // tail layer name is [ancestryName] - if i == len(layers)-1 { - currentName = ancestryName - } else { - currentName = fmt.Sprintf("%s-%d", ancestryName, i) - } - - // if rootName is unset, this is the first iteration over the layers and - // the current layer is the root of the ancestry - if rootName == "" { - rootName = currentName - } - - err := clair.ProcessLayer(s.Store, req.GetFormat(), currentName, parentName, layer.GetPath(), layer.GetHeaders()) - if err != nil { - return nil, s.rollBackOnError(err, currentName, rootName) - } - - // Now that the current layer is processed, set the parentName for the - // next iteration. - parentName = currentName - } - - return &pb.PostAncestryResponse{ - EngineVersion: clair.Version, - }, nil -} - -// GetAncestry implements retrieving an ancestry via the Clair gRPC service. -func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryRequest) (*pb.GetAncestryResponse, error) { - if req.GetAncestryName() == "" { - return nil, status.Errorf(codes.InvalidArgument, "invalid get ancestry request") - } - - // TODO(keyboardnerd): after altering the database to support ancestry, this - // function is iteratively querying for for r.GetIndex() th parent of the - // requested layer until the indexed layer is found or index is out of bound - // this is a hack and will be replaced with one query - ancestry, features, err := s.getAncestry(req.GetAncestryName(), req.GetWithFeatures(), req.GetWithVulnerabilities()) - if err == commonerr.ErrNotFound { - return nil, status.Error(codes.NotFound, err.Error()) - } else if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - return &pb.GetAncestryResponse{ - Ancestry: ancestry, - Features: features, - }, nil -} - -// GetNotification implements retrieving a notification via the Clair gRPC -// service. -func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNotificationRequest) (*pb.GetNotificationResponse, error) { - if req.GetName() == "" { - return nil, status.Error(codes.InvalidArgument, "Failed to provide notification name") - } - - if req.GetLimit() <= 0 { - return nil, status.Error(codes.InvalidArgument, "Failed to provide page limit") - } - - page := database.VulnerabilityNotificationFirstPage - pageToken := req.GetPage() - if pageToken != "" { - err := token.Unmarshal(pageToken, s.PaginationKey, &page) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Invalid page format %s", err.Error()) - } - } else { - pageTokenBytes, err := token.Marshal(page, s.PaginationKey) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Failed to marshal token: %s", err.Error()) - } - pageToken = string(pageTokenBytes) - } - - dbNotification, nextPage, err := s.Store.GetNotification(req.GetName(), int(req.GetLimit()), page) - if err == commonerr.ErrNotFound { - return nil, status.Error(codes.NotFound, err.Error()) - } else if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - notification, err := pb.NotificationFromDatabaseModel(dbNotification, int(req.GetLimit()), pageToken, nextPage, s.PaginationKey) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - return &pb.GetNotificationResponse{Notification: notification}, nil -} - -// DeleteNotification implements deleting a notification via the Clair gRPC -// service. -func (s *NotificationServer) DeleteNotification(ctx context.Context, req *pb.DeleteNotificationRequest) (*google_protobuf1.Empty, error) { - if req.GetName() == "" { - return nil, status.Error(codes.InvalidArgument, "Failed to provide notification name") - } - - err := s.Store.DeleteNotification(req.GetName()) - if err == commonerr.ErrNotFound { - return nil, status.Error(codes.NotFound, err.Error()) - } else if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - return &google_protobuf1.Empty{}, nil -} - -// rollBackOnError handles server error and rollback whole ancestry insertion if -// any layer failed to be inserted. -func (s *AncestryServer) rollBackOnError(err error, currentLayerName, rootLayerName string) error { - // if the current layer failed to be inserted and it's the root layer, - // then the ancestry is not yet in the database. - if currentLayerName != rootLayerName { - errrb := s.Store.DeleteLayer(rootLayerName) - if errrb != nil { - return status.Errorf(codes.Internal, errrb.Error()) - } - log.WithField("layer name", currentLayerName).Warnf("Can't process %s: roll back the ancestry", currentLayerName) - } - - if err == tarutil.ErrCouldNotExtract || - err == tarutil.ErrExtractedFileTooBig || - err == clair.ErrUnsupported { - return status.Errorf(codes.InvalidArgument, "unprocessable entity %s", err.Error()) - } - - if _, badreq := err.(*commonerr.ErrBadRequest); badreq { - return status.Error(codes.InvalidArgument, err.Error()) - } - - return status.Error(codes.Internal, err.Error()) -} - -// TODO(keyboardnerd): Remove this Legacy compability code once the database is -// revised. -// getAncestry returns an ancestry from database by getting all parents of a -// layer given the layer name, and the layer's feature list if -// withFeature/withVulnerability is turned on. -func (s *AncestryServer) getAncestry(name string, withFeature bool, withVulnerability bool) (ancestry *pb.Ancestry, features []*pb.Feature, err error) { - var ( - layers = []*pb.Layer{} - layer database.Layer - ) - ancestry = &pb.Ancestry{} - - layer, err = s.Store.FindLayer(name, withFeature, withVulnerability) - if err != nil { - return - } - - if withFeature { - for _, fv := range layer.Features { - f, e := pb.FeatureFromDatabaseModel(fv, withVulnerability) - if e != nil { - err = e - return - } - - features = append(features, f) - } - } - - ancestry.Name = name - ancestry.EngineVersion = int32(layer.EngineVersion) - for name != "" { - layer, err = s.Store.FindLayer(name, false, false) - if err != nil { - return - } - - if layer.Parent != nil { - name = layer.Parent.Name - } else { - name = "" - } - - layers = append(layers, pb.LayerFromDatabaseModel(layer)) - } - - // reverse layers to make the root layer at the top - for i, j := 0, len(layers)-1; i < j; i, j = i+1, j-1 { - layers[i], layers[j] = layers[j], layers[i] - } - - ancestry.Layers = layers - return -} diff --git a/api/v2/clairpb/Makefile b/api/v3/clairpb/Makefile similarity index 100% rename from api/v2/clairpb/Makefile rename to api/v3/clairpb/Makefile diff --git a/api/v2/clairpb/clair.pb.go b/api/v3/clairpb/clair.pb.go similarity index 52% rename from api/v2/clairpb/clair.pb.go rename to api/v3/clairpb/clair.pb.go index 6e5255d4..19816099 100644 --- a/api/v2/clairpb/clair.pb.go +++ b/api/v3/clairpb/clair.pb.go @@ -9,20 +9,20 @@ It is generated from these files: It has these top-level messages: Vulnerability + ClairStatus Feature Ancestry - LayersIntroducingVulnerabilty - OrderedLayerName Layer Notification - Page + IndexedAncestryName + PagedVulnerableAncestries PostAncestryRequest PostAncestryResponse GetAncestryRequest GetAncestryResponse GetNotificationRequest GetNotificationResponse - DeleteNotificationRequest + MarkNotificationAsReadRequest */ package clairpb @@ -31,6 +31,7 @@ import fmt "fmt" import math "math" import _ "google.golang.org/genproto/googleapis/api/annotations" import google_protobuf1 "github.com/golang/protobuf/ptypes/empty" +import google_protobuf2 "github.com/golang/protobuf/ptypes/timestamp" import ( context "golang.org/x/net/context" @@ -49,14 +50,16 @@ var _ = math.Inf const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package type Vulnerability struct { - Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` - NamespaceName string `protobuf:"bytes,2,opt,name=namespace_name,json=namespaceName" json:"namespace_name,omitempty"` - Description string `protobuf:"bytes,3,opt,name=description" json:"description,omitempty"` - Link string `protobuf:"bytes,4,opt,name=link" json:"link,omitempty"` - Severity string `protobuf:"bytes,5,opt,name=severity" json:"severity,omitempty"` - Metadata string `protobuf:"bytes,6,opt,name=metadata" json:"metadata,omitempty"` - FixedBy string `protobuf:"bytes,7,opt,name=fixed_by,json=fixedBy" json:"fixed_by,omitempty"` - FixedInFeatures []*Feature `protobuf:"bytes,8,rep,name=fixed_in_features,json=fixedInFeatures" json:"fixed_in_features,omitempty"` + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + NamespaceName string `protobuf:"bytes,2,opt,name=namespace_name,json=namespaceName" json:"namespace_name,omitempty"` + Description string `protobuf:"bytes,3,opt,name=description" json:"description,omitempty"` + Link string `protobuf:"bytes,4,opt,name=link" json:"link,omitempty"` + Severity string `protobuf:"bytes,5,opt,name=severity" json:"severity,omitempty"` + Metadata string `protobuf:"bytes,6,opt,name=metadata" json:"metadata,omitempty"` + // fixed_by exists when vulnerability is under feature. + FixedBy string `protobuf:"bytes,7,opt,name=fixed_by,json=fixedBy" json:"fixed_by,omitempty"` + // affected_versions exists when vulnerability is under notification. + AffectedVersions []*Feature `protobuf:"bytes,8,rep,name=affected_versions,json=affectedVersions" json:"affected_versions,omitempty"` } func (m *Vulnerability) Reset() { *m = Vulnerability{} } @@ -113,9 +116,43 @@ func (m *Vulnerability) GetFixedBy() string { return "" } -func (m *Vulnerability) GetFixedInFeatures() []*Feature { +func (m *Vulnerability) GetAffectedVersions() []*Feature { if m != nil { - return m.FixedInFeatures + return m.AffectedVersions + } + return nil +} + +type ClairStatus struct { + // listers and detectors are processors implemented in this Clair and used to + // scan ancestries + Listers []string `protobuf:"bytes,1,rep,name=listers" json:"listers,omitempty"` + Detectors []string `protobuf:"bytes,2,rep,name=detectors" json:"detectors,omitempty"` + LastUpdateTime *google_protobuf2.Timestamp `protobuf:"bytes,3,opt,name=last_update_time,json=lastUpdateTime" json:"last_update_time,omitempty"` +} + +func (m *ClairStatus) Reset() { *m = ClairStatus{} } +func (m *ClairStatus) String() string { return proto.CompactTextString(m) } +func (*ClairStatus) ProtoMessage() {} +func (*ClairStatus) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *ClairStatus) GetListers() []string { + if m != nil { + return m.Listers + } + return nil +} + +func (m *ClairStatus) GetDetectors() []string { + if m != nil { + return m.Detectors + } + return nil +} + +func (m *ClairStatus) GetLastUpdateTime() *google_protobuf2.Timestamp { + if m != nil { + return m.LastUpdateTime } return nil } @@ -125,14 +162,13 @@ type Feature struct { NamespaceName string `protobuf:"bytes,2,opt,name=namespace_name,json=namespaceName" json:"namespace_name,omitempty"` Version string `protobuf:"bytes,3,opt,name=version" json:"version,omitempty"` VersionFormat string `protobuf:"bytes,4,opt,name=version_format,json=versionFormat" json:"version_format,omitempty"` - AddedBy string `protobuf:"bytes,5,opt,name=added_by,json=addedBy" json:"added_by,omitempty"` - Vulnerabilities []*Vulnerability `protobuf:"bytes,6,rep,name=vulnerabilities" json:"vulnerabilities,omitempty"` + Vulnerabilities []*Vulnerability `protobuf:"bytes,5,rep,name=vulnerabilities" json:"vulnerabilities,omitempty"` } func (m *Feature) Reset() { *m = Feature{} } func (m *Feature) String() string { return proto.CompactTextString(m) } func (*Feature) ProtoMessage() {} -func (*Feature) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } +func (*Feature) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } func (m *Feature) GetName() string { if m != nil { @@ -162,13 +198,6 @@ func (m *Feature) GetVersionFormat() string { return "" } -func (m *Feature) GetAddedBy() string { - if m != nil { - return m.AddedBy - } - return "" -} - func (m *Feature) GetVulnerabilities() []*Vulnerability { if m != nil { return m.Vulnerabilities @@ -177,15 +206,20 @@ func (m *Feature) GetVulnerabilities() []*Vulnerability { } type Ancestry struct { - Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` - EngineVersion int32 `protobuf:"varint,2,opt,name=engine_version,json=engineVersion" json:"engine_version,omitempty"` - Layers []*Layer `protobuf:"bytes,3,rep,name=layers" json:"layers,omitempty"` + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Features []*Feature `protobuf:"bytes,2,rep,name=features" json:"features,omitempty"` + Layers []*Layer `protobuf:"bytes,3,rep,name=layers" json:"layers,omitempty"` + // scanned_listers and scanned_detectors are used to scan this ancestry, it + // may be different from listers and detectors in ClairStatus since the + // ancestry could be scanned by previous version of Clair. + ScannedListers []string `protobuf:"bytes,4,rep,name=scanned_listers,json=scannedListers" json:"scanned_listers,omitempty"` + ScannedDetectors []string `protobuf:"bytes,5,rep,name=scanned_detectors,json=scannedDetectors" json:"scanned_detectors,omitempty"` } func (m *Ancestry) Reset() { *m = Ancestry{} } func (m *Ancestry) String() string { return proto.CompactTextString(m) } func (*Ancestry) ProtoMessage() {} -func (*Ancestry) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } +func (*Ancestry) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } func (m *Ancestry) GetName() string { if m != nil { @@ -194,11 +228,11 @@ func (m *Ancestry) GetName() string { return "" } -func (m *Ancestry) GetEngineVersion() int32 { +func (m *Ancestry) GetFeatures() []*Feature { if m != nil { - return m.EngineVersion + return m.Features } - return 0 + return nil } func (m *Ancestry) GetLayers() []*Layer { @@ -208,91 +242,49 @@ func (m *Ancestry) GetLayers() []*Layer { return nil } -type LayersIntroducingVulnerabilty struct { - Vulnerability *Vulnerability `protobuf:"bytes,1,opt,name=vulnerability" json:"vulnerability,omitempty"` - Layers []*OrderedLayerName `protobuf:"bytes,2,rep,name=layers" json:"layers,omitempty"` -} - -func (m *LayersIntroducingVulnerabilty) Reset() { *m = LayersIntroducingVulnerabilty{} } -func (m *LayersIntroducingVulnerabilty) String() string { return proto.CompactTextString(m) } -func (*LayersIntroducingVulnerabilty) ProtoMessage() {} -func (*LayersIntroducingVulnerabilty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } - -func (m *LayersIntroducingVulnerabilty) GetVulnerability() *Vulnerability { +func (m *Ancestry) GetScannedListers() []string { if m != nil { - return m.Vulnerability + return m.ScannedListers } return nil } -func (m *LayersIntroducingVulnerabilty) GetLayers() []*OrderedLayerName { +func (m *Ancestry) GetScannedDetectors() []string { if m != nil { - return m.Layers + return m.ScannedDetectors } return nil } -type OrderedLayerName struct { - Index int32 `protobuf:"varint,1,opt,name=index" json:"index,omitempty"` - LayerName string `protobuf:"bytes,2,opt,name=layer_name,json=layerName" json:"layer_name,omitempty"` -} - -func (m *OrderedLayerName) Reset() { *m = OrderedLayerName{} } -func (m *OrderedLayerName) String() string { return proto.CompactTextString(m) } -func (*OrderedLayerName) ProtoMessage() {} -func (*OrderedLayerName) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } - -func (m *OrderedLayerName) GetIndex() int32 { - if m != nil { - return m.Index - } - return 0 -} - -func (m *OrderedLayerName) GetLayerName() string { - if m != nil { - return m.LayerName - } - return "" -} - type Layer struct { - Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` - NamespaceNames []string `protobuf:"bytes,2,rep,name=namespace_names,json=namespaceNames" json:"namespace_names,omitempty"` + Hash string `protobuf:"bytes,1,opt,name=hash" json:"hash,omitempty"` } func (m *Layer) Reset() { *m = Layer{} } func (m *Layer) String() string { return proto.CompactTextString(m) } func (*Layer) ProtoMessage() {} -func (*Layer) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } +func (*Layer) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } -func (m *Layer) GetName() string { +func (m *Layer) GetHash() string { if m != nil { - return m.Name + return m.Hash } return "" } -func (m *Layer) GetNamespaceNames() []string { - if m != nil { - return m.NamespaceNames - } - return nil -} - type Notification struct { - Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` - Created string `protobuf:"bytes,2,opt,name=created" json:"created,omitempty"` - Notified string `protobuf:"bytes,3,opt,name=notified" json:"notified,omitempty"` - Deleted string `protobuf:"bytes,4,opt,name=deleted" json:"deleted,omitempty"` - Limit int32 `protobuf:"varint,5,opt,name=limit" json:"limit,omitempty"` - Page *Page `protobuf:"bytes,6,opt,name=page" json:"page,omitempty"` + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Created string `protobuf:"bytes,2,opt,name=created" json:"created,omitempty"` + Notified string `protobuf:"bytes,3,opt,name=notified" json:"notified,omitempty"` + Deleted string `protobuf:"bytes,4,opt,name=deleted" json:"deleted,omitempty"` + Old *PagedVulnerableAncestries `protobuf:"bytes,5,opt,name=old" json:"old,omitempty"` + New *PagedVulnerableAncestries `protobuf:"bytes,6,opt,name=new" json:"new,omitempty"` } func (m *Notification) Reset() { *m = Notification{} } func (m *Notification) String() string { return proto.CompactTextString(m) } func (*Notification) ProtoMessage() {} -func (*Notification) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } +func (*Notification) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } func (m *Notification) GetName() string { if m != nil { @@ -322,60 +314,95 @@ func (m *Notification) GetDeleted() string { return "" } -func (m *Notification) GetLimit() int32 { - if m != nil { - return m.Limit - } - return 0 -} - -func (m *Notification) GetPage() *Page { - if m != nil { - return m.Page - } - return nil -} - -type Page struct { - ThisToken string `protobuf:"bytes,1,opt,name=this_token,json=thisToken" json:"this_token,omitempty"` - NextToken string `protobuf:"bytes,2,opt,name=next_token,json=nextToken" json:"next_token,omitempty"` - Old *LayersIntroducingVulnerabilty `protobuf:"bytes,3,opt,name=old" json:"old,omitempty"` - New *LayersIntroducingVulnerabilty `protobuf:"bytes,4,opt,name=new" json:"new,omitempty"` -} - -func (m *Page) Reset() { *m = Page{} } -func (m *Page) String() string { return proto.CompactTextString(m) } -func (*Page) ProtoMessage() {} -func (*Page) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } - -func (m *Page) GetThisToken() string { - if m != nil { - return m.ThisToken - } - return "" -} - -func (m *Page) GetNextToken() string { - if m != nil { - return m.NextToken - } - return "" -} - -func (m *Page) GetOld() *LayersIntroducingVulnerabilty { +func (m *Notification) GetOld() *PagedVulnerableAncestries { if m != nil { return m.Old } return nil } -func (m *Page) GetNew() *LayersIntroducingVulnerabilty { +func (m *Notification) GetNew() *PagedVulnerableAncestries { if m != nil { return m.New } return nil } +type IndexedAncestryName struct { + // index is unique to name in all streams simultaneously streamed, increasing + // and larger than all indexes in previous page in same stream. + Index int32 `protobuf:"varint,1,opt,name=index" json:"index,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name" json:"name,omitempty"` +} + +func (m *IndexedAncestryName) Reset() { *m = IndexedAncestryName{} } +func (m *IndexedAncestryName) String() string { return proto.CompactTextString(m) } +func (*IndexedAncestryName) ProtoMessage() {} +func (*IndexedAncestryName) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + +func (m *IndexedAncestryName) GetIndex() int32 { + if m != nil { + return m.Index + } + return 0 +} + +func (m *IndexedAncestryName) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type PagedVulnerableAncestries struct { + CurrentPage string `protobuf:"bytes,1,opt,name=current_page,json=currentPage" json:"current_page,omitempty"` + // if next_page is empty, it signals the end of all pages. + NextPage string `protobuf:"bytes,2,opt,name=next_page,json=nextPage" json:"next_page,omitempty"` + Limit int32 `protobuf:"varint,3,opt,name=limit" json:"limit,omitempty"` + Vulnerability *Vulnerability `protobuf:"bytes,4,opt,name=vulnerability" json:"vulnerability,omitempty"` + Ancestries []*IndexedAncestryName `protobuf:"bytes,5,rep,name=ancestries" json:"ancestries,omitempty"` +} + +func (m *PagedVulnerableAncestries) Reset() { *m = PagedVulnerableAncestries{} } +func (m *PagedVulnerableAncestries) String() string { return proto.CompactTextString(m) } +func (*PagedVulnerableAncestries) ProtoMessage() {} +func (*PagedVulnerableAncestries) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } + +func (m *PagedVulnerableAncestries) GetCurrentPage() string { + if m != nil { + return m.CurrentPage + } + return "" +} + +func (m *PagedVulnerableAncestries) GetNextPage() string { + if m != nil { + return m.NextPage + } + return "" +} + +func (m *PagedVulnerableAncestries) GetLimit() int32 { + if m != nil { + return m.Limit + } + return 0 +} + +func (m *PagedVulnerableAncestries) GetVulnerability() *Vulnerability { + if m != nil { + return m.Vulnerability + } + return nil +} + +func (m *PagedVulnerableAncestries) GetAncestries() []*IndexedAncestryName { + if m != nil { + return m.Ancestries + } + return nil +} + type PostAncestryRequest struct { AncestryName string `protobuf:"bytes,1,opt,name=ancestry_name,json=ancestryName" json:"ancestry_name,omitempty"` Format string `protobuf:"bytes,2,opt,name=format" json:"format,omitempty"` @@ -409,7 +436,7 @@ func (m *PostAncestryRequest) GetLayers() []*PostAncestryRequest_PostLayer { } type PostAncestryRequest_PostLayer struct { - Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Hash string `protobuf:"bytes,1,opt,name=hash" json:"hash,omitempty"` Path string `protobuf:"bytes,2,opt,name=path" json:"path,omitempty"` Headers map[string]string `protobuf:"bytes,3,rep,name=headers" json:"headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` } @@ -421,9 +448,9 @@ func (*PostAncestryRequest_PostLayer) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8, 0} } -func (m *PostAncestryRequest_PostLayer) GetName() string { +func (m *PostAncestryRequest_PostLayer) GetHash() string { if m != nil { - return m.Name + return m.Hash } return "" } @@ -443,7 +470,7 @@ func (m *PostAncestryRequest_PostLayer) GetHeaders() map[string]string { } type PostAncestryResponse struct { - EngineVersion int32 `protobuf:"varint,1,opt,name=engine_version,json=engineVersion" json:"engine_version,omitempty"` + Status *ClairStatus `protobuf:"bytes,1,opt,name=status" json:"status,omitempty"` } func (m *PostAncestryResponse) Reset() { *m = PostAncestryResponse{} } @@ -451,11 +478,11 @@ func (m *PostAncestryResponse) String() string { return proto.Compact func (*PostAncestryResponse) ProtoMessage() {} func (*PostAncestryResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{9} } -func (m *PostAncestryResponse) GetEngineVersion() int32 { +func (m *PostAncestryResponse) GetStatus() *ClairStatus { if m != nil { - return m.EngineVersion + return m.Status } - return 0 + return nil } type GetAncestryRequest struct { @@ -491,8 +518,8 @@ func (m *GetAncestryRequest) GetWithFeatures() bool { } type GetAncestryResponse struct { - Ancestry *Ancestry `protobuf:"bytes,1,opt,name=ancestry" json:"ancestry,omitempty"` - Features []*Feature `protobuf:"bytes,2,rep,name=features" json:"features,omitempty"` + Ancestry *Ancestry `protobuf:"bytes,1,opt,name=ancestry" json:"ancestry,omitempty"` + Status *ClairStatus `protobuf:"bytes,2,opt,name=status" json:"status,omitempty"` } func (m *GetAncestryResponse) Reset() { *m = GetAncestryResponse{} } @@ -507,17 +534,19 @@ func (m *GetAncestryResponse) GetAncestry() *Ancestry { return nil } -func (m *GetAncestryResponse) GetFeatures() []*Feature { +func (m *GetAncestryResponse) GetStatus() *ClairStatus { if m != nil { - return m.Features + return m.Status } return nil } type GetNotificationRequest struct { - Page string `protobuf:"bytes,1,opt,name=page" json:"page,omitempty"` - Limit int32 `protobuf:"varint,2,opt,name=limit" json:"limit,omitempty"` - Name string `protobuf:"bytes,3,opt,name=name" json:"name,omitempty"` + // if the vulnerability_page is empty, it implies the first page. + OldVulnerabilityPage string `protobuf:"bytes,1,opt,name=old_vulnerability_page,json=oldVulnerabilityPage" json:"old_vulnerability_page,omitempty"` + NewVulnerabilityPage string `protobuf:"bytes,2,opt,name=new_vulnerability_page,json=newVulnerabilityPage" json:"new_vulnerability_page,omitempty"` + Limit int32 `protobuf:"varint,3,opt,name=limit" json:"limit,omitempty"` + Name string `protobuf:"bytes,4,opt,name=name" json:"name,omitempty"` } func (m *GetNotificationRequest) Reset() { *m = GetNotificationRequest{} } @@ -525,9 +554,16 @@ func (m *GetNotificationRequest) String() string { return proto.Compa func (*GetNotificationRequest) ProtoMessage() {} func (*GetNotificationRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{12} } -func (m *GetNotificationRequest) GetPage() string { +func (m *GetNotificationRequest) GetOldVulnerabilityPage() string { if m != nil { - return m.Page + return m.OldVulnerabilityPage + } + return "" +} + +func (m *GetNotificationRequest) GetNewVulnerabilityPage() string { + if m != nil { + return m.NewVulnerabilityPage } return "" } @@ -562,16 +598,16 @@ func (m *GetNotificationResponse) GetNotification() *Notification { return nil } -type DeleteNotificationRequest struct { +type MarkNotificationAsReadRequest struct { Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` } -func (m *DeleteNotificationRequest) Reset() { *m = DeleteNotificationRequest{} } -func (m *DeleteNotificationRequest) String() string { return proto.CompactTextString(m) } -func (*DeleteNotificationRequest) ProtoMessage() {} -func (*DeleteNotificationRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{14} } +func (m *MarkNotificationAsReadRequest) Reset() { *m = MarkNotificationAsReadRequest{} } +func (m *MarkNotificationAsReadRequest) String() string { return proto.CompactTextString(m) } +func (*MarkNotificationAsReadRequest) ProtoMessage() {} +func (*MarkNotificationAsReadRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{14} } -func (m *DeleteNotificationRequest) GetName() string { +func (m *MarkNotificationAsReadRequest) GetName() string { if m != nil { return m.Name } @@ -580,13 +616,13 @@ func (m *DeleteNotificationRequest) GetName() string { func init() { proto.RegisterType((*Vulnerability)(nil), "clairpb.Vulnerability") + proto.RegisterType((*ClairStatus)(nil), "clairpb.ClairStatus") proto.RegisterType((*Feature)(nil), "clairpb.Feature") proto.RegisterType((*Ancestry)(nil), "clairpb.Ancestry") - proto.RegisterType((*LayersIntroducingVulnerabilty)(nil), "clairpb.LayersIntroducingVulnerabilty") - proto.RegisterType((*OrderedLayerName)(nil), "clairpb.OrderedLayerName") proto.RegisterType((*Layer)(nil), "clairpb.Layer") proto.RegisterType((*Notification)(nil), "clairpb.Notification") - proto.RegisterType((*Page)(nil), "clairpb.Page") + proto.RegisterType((*IndexedAncestryName)(nil), "clairpb.IndexedAncestryName") + proto.RegisterType((*PagedVulnerableAncestries)(nil), "clairpb.PagedVulnerableAncestries") proto.RegisterType((*PostAncestryRequest)(nil), "clairpb.PostAncestryRequest") proto.RegisterType((*PostAncestryRequest_PostLayer)(nil), "clairpb.PostAncestryRequest.PostLayer") proto.RegisterType((*PostAncestryResponse)(nil), "clairpb.PostAncestryResponse") @@ -594,7 +630,7 @@ func init() { proto.RegisterType((*GetAncestryResponse)(nil), "clairpb.GetAncestryResponse") proto.RegisterType((*GetNotificationRequest)(nil), "clairpb.GetNotificationRequest") proto.RegisterType((*GetNotificationResponse)(nil), "clairpb.GetNotificationResponse") - proto.RegisterType((*DeleteNotificationRequest)(nil), "clairpb.DeleteNotificationRequest") + proto.RegisterType((*MarkNotificationAsReadRequest)(nil), "clairpb.MarkNotificationAsReadRequest") } // Reference imports to suppress errors if they are not otherwise used. @@ -706,7 +742,7 @@ var _AncestryService_serviceDesc = grpc.ServiceDesc{ type NotificationServiceClient interface { GetNotification(ctx context.Context, in *GetNotificationRequest, opts ...grpc.CallOption) (*GetNotificationResponse, error) - DeleteNotification(ctx context.Context, in *DeleteNotificationRequest, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) + MarkNotificationAsRead(ctx context.Context, in *MarkNotificationAsReadRequest, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) } type notificationServiceClient struct { @@ -726,9 +762,9 @@ func (c *notificationServiceClient) GetNotification(ctx context.Context, in *Get return out, nil } -func (c *notificationServiceClient) DeleteNotification(ctx context.Context, in *DeleteNotificationRequest, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) { +func (c *notificationServiceClient) MarkNotificationAsRead(ctx context.Context, in *MarkNotificationAsReadRequest, opts ...grpc.CallOption) (*google_protobuf1.Empty, error) { out := new(google_protobuf1.Empty) - err := grpc.Invoke(ctx, "/clairpb.NotificationService/DeleteNotification", in, out, c.cc, opts...) + err := grpc.Invoke(ctx, "/clairpb.NotificationService/MarkNotificationAsRead", in, out, c.cc, opts...) if err != nil { return nil, err } @@ -739,7 +775,7 @@ func (c *notificationServiceClient) DeleteNotification(ctx context.Context, in * type NotificationServiceServer interface { GetNotification(context.Context, *GetNotificationRequest) (*GetNotificationResponse, error) - DeleteNotification(context.Context, *DeleteNotificationRequest) (*google_protobuf1.Empty, error) + MarkNotificationAsRead(context.Context, *MarkNotificationAsReadRequest) (*google_protobuf1.Empty, error) } func RegisterNotificationServiceServer(s *grpc.Server, srv NotificationServiceServer) { @@ -764,20 +800,20 @@ func _NotificationService_GetNotification_Handler(srv interface{}, ctx context.C return interceptor(ctx, in, info, handler) } -func _NotificationService_DeleteNotification_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(DeleteNotificationRequest) +func _NotificationService_MarkNotificationAsRead_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(MarkNotificationAsReadRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(NotificationServiceServer).DeleteNotification(ctx, in) + return srv.(NotificationServiceServer).MarkNotificationAsRead(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/clairpb.NotificationService/DeleteNotification", + FullMethod: "/clairpb.NotificationService/MarkNotificationAsRead", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotificationServiceServer).DeleteNotification(ctx, req.(*DeleteNotificationRequest)) + return srv.(NotificationServiceServer).MarkNotificationAsRead(ctx, req.(*MarkNotificationAsReadRequest)) } return interceptor(ctx, in, info, handler) } @@ -791,8 +827,8 @@ var _NotificationService_serviceDesc = grpc.ServiceDesc{ Handler: _NotificationService_GetNotification_Handler, }, { - MethodName: "DeleteNotification", - Handler: _NotificationService_DeleteNotification_Handler, + MethodName: "MarkNotificationAsRead", + Handler: _NotificationService_MarkNotificationAsRead_Handler, }, }, Streams: []grpc.StreamDesc{}, @@ -802,71 +838,78 @@ var _NotificationService_serviceDesc = grpc.ServiceDesc{ func init() { proto.RegisterFile("clair.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 1042 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x56, 0xdd, 0x6e, 0x1b, 0x45, - 0x14, 0xd6, 0xda, 0x71, 0x6c, 0x1f, 0xdb, 0x49, 0x3a, 0x49, 0xd3, 0x8d, 0x93, 0x88, 0x74, 0x11, - 0xa5, 0xaa, 0xc0, 0x56, 0xd3, 0x9b, 0x12, 0x01, 0x82, 0xa8, 0x6d, 0xa8, 0x04, 0xa5, 0x5a, 0xaa, - 0x5c, 0x70, 0x63, 0x4d, 0xbc, 0x27, 0xce, 0x28, 0xeb, 0x59, 0xb3, 0x3b, 0x76, 0x62, 0x55, 0xdc, - 0xf0, 0x04, 0x54, 0x3c, 0x06, 0x2f, 0xc0, 0x15, 0x2f, 0xd1, 0x27, 0x00, 0xf1, 0x16, 0xdc, 0xa0, - 0xf9, 0xf5, 0xae, 0x63, 0x23, 0x7e, 0xae, 0x3c, 0xe7, 0x7c, 0xe7, 0xe7, 0x3b, 0x3f, 0x33, 0x5e, - 0x68, 0xf4, 0x63, 0xca, 0xd2, 0xce, 0x28, 0x4d, 0x44, 0x42, 0xaa, 0x4a, 0x18, 0x9d, 0xb5, 0xf7, - 0x06, 0x49, 0x32, 0x88, 0xb1, 0x4b, 0x47, 0xac, 0x4b, 0x39, 0x4f, 0x04, 0x15, 0x2c, 0xe1, 0x99, - 0x36, 0x6b, 0xef, 0x1a, 0x54, 0x49, 0x67, 0xe3, 0xf3, 0x2e, 0x0e, 0x47, 0x62, 0xaa, 0xc1, 0xe0, - 0x4d, 0x09, 0x5a, 0xa7, 0xe3, 0x98, 0x63, 0x4a, 0xcf, 0x58, 0xcc, 0xc4, 0x94, 0x10, 0x58, 0xe1, - 0x74, 0x88, 0xbe, 0x77, 0xe0, 0xdd, 0xaf, 0x87, 0xea, 0x4c, 0xde, 0x83, 0x35, 0xf9, 0x9b, 0x8d, - 0x68, 0x1f, 0x7b, 0x0a, 0x2d, 0x29, 0xb4, 0xe5, 0xb4, 0x2f, 0xa4, 0xd9, 0x01, 0x34, 0x22, 0xcc, - 0xfa, 0x29, 0x1b, 0xc9, 0xfc, 0x7e, 0x59, 0xd9, 0xe4, 0x55, 0x32, 0x78, 0xcc, 0xf8, 0xa5, 0xbf, - 0xa2, 0x83, 0xcb, 0x33, 0x69, 0x43, 0x2d, 0xc3, 0x09, 0xa6, 0x4c, 0x4c, 0xfd, 0x8a, 0xd2, 0x3b, - 0x59, 0x62, 0x43, 0x14, 0x34, 0xa2, 0x82, 0xfa, 0xab, 0x1a, 0xb3, 0x32, 0xd9, 0x81, 0xda, 0x39, - 0xbb, 0xc6, 0xa8, 0x77, 0x36, 0xf5, 0xab, 0x0a, 0xab, 0x2a, 0xf9, 0x78, 0x4a, 0x3e, 0x86, 0x5b, - 0x1a, 0x62, 0xbc, 0x77, 0x8e, 0x54, 0x8c, 0x53, 0xcc, 0xfc, 0xda, 0x41, 0xf9, 0x7e, 0xe3, 0x70, - 0xa3, 0x63, 0xba, 0xd6, 0x79, 0xa6, 0x81, 0x70, 0x5d, 0x99, 0x3e, 0xe7, 0x46, 0xce, 0x82, 0xdf, - 0x3d, 0xa8, 0x1a, 0xe1, 0xff, 0x74, 0xc3, 0x87, 0xea, 0x04, 0xd3, 0x6c, 0xd6, 0x09, 0x2b, 0xca, - 0x00, 0xe6, 0xd8, 0x3b, 0x4f, 0xd2, 0x21, 0x15, 0xa6, 0x1f, 0x2d, 0xa3, 0x7d, 0xa6, 0x94, 0xb2, - 0x40, 0x1a, 0x45, 0xba, 0x40, 0xdd, 0x98, 0xaa, 0x92, 0x8f, 0xa7, 0xe4, 0x33, 0x58, 0x9f, 0xe4, - 0xa6, 0xc6, 0x30, 0xf3, 0x57, 0x55, 0x79, 0xdb, 0xae, 0xbc, 0xc2, 0x54, 0xc3, 0x79, 0xf3, 0x60, - 0x08, 0xb5, 0xcf, 0x79, 0x1f, 0x33, 0x91, 0x2e, 0x1d, 0x39, 0xf2, 0x01, 0xe3, 0xd8, 0xb3, 0x45, - 0xc8, 0x22, 0x2b, 0x61, 0x4b, 0x6b, 0x4f, 0x4d, 0x29, 0xf7, 0x60, 0x35, 0xa6, 0x53, 0x4c, 0x33, - 0xbf, 0xac, 0xf2, 0xaf, 0xb9, 0xfc, 0x5f, 0x4a, 0x75, 0x68, 0xd0, 0xe0, 0x47, 0x0f, 0xf6, 0x95, - 0x26, 0x7b, 0xce, 0x45, 0x9a, 0x44, 0xe3, 0x3e, 0xe3, 0x83, 0x19, 0x45, 0x21, 0x67, 0xd6, 0xca, - 0x73, 0x9c, 0x2a, 0x36, 0xcb, 0x0b, 0x2a, 0x1a, 0x93, 0x87, 0x8e, 0x47, 0x49, 0xf1, 0xd8, 0x71, - 0x6e, 0x5f, 0xa7, 0x11, 0xa6, 0x18, 0xa9, 0xe4, 0x72, 0x2e, 0x8e, 0xd2, 0x09, 0x6c, 0xcc, 0x63, - 0x64, 0x0b, 0x2a, 0x8c, 0x47, 0x78, 0xad, 0x92, 0x57, 0x42, 0x2d, 0x90, 0x7d, 0x00, 0xe5, 0x93, - 0x1f, 0x76, 0x3d, 0xb6, 0x4e, 0xc1, 0x13, 0xa8, 0xa8, 0x08, 0x0b, 0xfb, 0xf8, 0x3e, 0xac, 0x17, - 0x97, 0x45, 0x33, 0xac, 0x87, 0x6b, 0x85, 0x6d, 0xc9, 0x82, 0x9f, 0x3d, 0x68, 0xbe, 0x48, 0x04, - 0x3b, 0x67, 0x7d, 0x6a, 0xef, 0xca, 0x8d, 0x68, 0x3e, 0x54, 0xfb, 0x29, 0x52, 0x81, 0x91, 0xa1, - 0x61, 0x45, 0x79, 0x53, 0xb8, 0xf2, 0xc6, 0xc8, 0xac, 0x9b, 0x93, 0xa5, 0x57, 0x84, 0x31, 0x4a, - 0x2f, 0xbd, 0x68, 0x56, 0x94, 0xf5, 0xc6, 0x6c, 0xc8, 0x84, 0xda, 0xaf, 0x4a, 0xa8, 0x05, 0x72, - 0x17, 0x56, 0x46, 0x74, 0x80, 0xea, 0xc6, 0x35, 0x0e, 0x5b, 0xae, 0x95, 0x2f, 0xe9, 0x00, 0x43, - 0x05, 0x05, 0xbf, 0x78, 0xb0, 0x22, 0x45, 0xd9, 0x1b, 0x71, 0xc1, 0xb2, 0x9e, 0x48, 0x2e, 0x91, - 0x1b, 0xae, 0x75, 0xa9, 0x79, 0x25, 0x15, 0x12, 0xe6, 0x78, 0x2d, 0x0c, 0x6c, 0x5a, 0x27, 0x35, - 0x1a, 0x7e, 0x0c, 0xe5, 0x24, 0xd6, 0x84, 0x1b, 0x87, 0xf7, 0x8a, 0xbb, 0xb3, 0x6c, 0x53, 0x42, - 0xe9, 0x22, 0x3d, 0x39, 0x5e, 0xa9, 0x7a, 0xfe, 0x85, 0x27, 0xc7, 0xab, 0xe0, 0x6d, 0x09, 0x36, - 0x5f, 0x26, 0x99, 0xb0, 0xeb, 0x1f, 0xe2, 0x77, 0x63, 0xcc, 0x04, 0x79, 0x17, 0x5a, 0xd4, 0xa8, - 0x7a, 0xb9, 0xc6, 0x37, 0xad, 0x52, 0x2d, 0xc8, 0x36, 0xac, 0x9a, 0x2b, 0xab, 0x6b, 0x31, 0x12, - 0xf9, 0x74, 0xee, 0x1e, 0xcc, 0x18, 0x2d, 0x48, 0xa5, 0x74, 0x85, 0xfb, 0xd1, 0xfe, 0xd5, 0x83, - 0xba, 0xd3, 0x2e, 0x1c, 0x3d, 0x91, 0x43, 0x11, 0x17, 0x26, 0xaf, 0x3a, 0x93, 0xaf, 0xa0, 0x7a, - 0x81, 0x34, 0x9a, 0xa5, 0x7d, 0xf4, 0xcf, 0xd2, 0x76, 0xbe, 0xd0, 0x5e, 0x4f, 0xb9, 0x44, 0x6d, - 0x8c, 0xf6, 0x11, 0x34, 0xf3, 0x00, 0xd9, 0x80, 0xf2, 0x25, 0x4e, 0x0d, 0x0b, 0x79, 0x94, 0xfb, - 0x32, 0xa1, 0xf1, 0xd8, 0x5e, 0x02, 0x2d, 0x1c, 0x95, 0x1e, 0x7b, 0xc1, 0x27, 0xb0, 0x55, 0x4c, - 0x99, 0x8d, 0x12, 0x9e, 0x2d, 0x7a, 0x47, 0xbc, 0x05, 0xef, 0x48, 0xf0, 0xc6, 0x03, 0x72, 0x82, - 0xff, 0x6d, 0x26, 0x0f, 0x61, 0xeb, 0x8a, 0x89, 0x8b, 0xde, 0xfc, 0x8b, 0x28, 0x39, 0xd6, 0xc2, - 0x4d, 0x89, 0x9d, 0x16, 0x21, 0x19, 0x57, 0xb9, 0xb8, 0x3f, 0x87, 0xb2, 0xb2, 0x6d, 0x4a, 0xa5, - 0xfb, 0x1f, 0x48, 0x61, 0xb3, 0x40, 0xc9, 0x54, 0xf4, 0x21, 0xd4, 0x6c, 0x7a, 0xf3, 0x46, 0xdd, - 0x72, 0x5d, 0x77, 0xc6, 0xce, 0x84, 0x7c, 0x00, 0x35, 0x97, 0xa5, 0xb4, 0xe4, 0x2f, 0xc8, 0x59, - 0x04, 0xa7, 0xb0, 0x7d, 0x82, 0x22, 0xff, 0x0e, 0xd8, 0x56, 0x10, 0x73, 0x29, 0x3d, 0x3b, 0xff, - 0x01, 0xce, 0xae, 0x6f, 0x29, 0x7f, 0x7d, 0xed, 0xf6, 0x94, 0x67, 0xdb, 0x13, 0xbc, 0x82, 0x3b, - 0x37, 0xe2, 0x9a, 0x7a, 0x3e, 0x82, 0x26, 0xcf, 0xe9, 0x4d, 0x4d, 0xb7, 0x1d, 0xc9, 0x82, 0x53, - 0xc1, 0x34, 0xe8, 0xc2, 0xce, 0x13, 0xf5, 0x92, 0x2c, 0x21, 0x3c, 0xbf, 0xc4, 0x87, 0xbf, 0x79, - 0xb0, 0x6e, 0x7b, 0xf4, 0x0d, 0xa6, 0x13, 0xd6, 0x47, 0x42, 0xa1, 0x99, 0xdf, 0x1c, 0xb2, 0xf7, - 0x77, 0x3b, 0xdc, 0xde, 0x5f, 0x82, 0xea, 0x62, 0x82, 0xad, 0x1f, 0xde, 0xfe, 0xf1, 0x53, 0x69, - 0x2d, 0xa8, 0x77, 0xed, 0x00, 0x8e, 0xbc, 0x07, 0xe4, 0x12, 0x1a, 0xb9, 0x49, 0x92, 0x5d, 0x17, - 0xe3, 0xe6, 0xca, 0xb5, 0xf7, 0x16, 0x83, 0x26, 0xfe, 0x5d, 0x15, 0x7f, 0x97, 0xec, 0xb8, 0xf8, - 0xdd, 0xd7, 0x85, 0x0d, 0xfd, 0xfe, 0xf0, 0x4f, 0x0f, 0x36, 0xf3, 0xfd, 0xb0, 0x75, 0x66, 0xb0, - 0x3e, 0x37, 0x02, 0xf2, 0x4e, 0x3e, 0xd7, 0x82, 0x1e, 0xb6, 0x0f, 0x96, 0x1b, 0x18, 0x42, 0xfb, - 0x8a, 0xd0, 0x1d, 0x72, 0xbb, 0x9b, 0x9f, 0x4c, 0xd6, 0x7d, 0xad, 0xc8, 0x90, 0x04, 0xc8, 0xcd, - 0x09, 0x91, 0xc0, 0x85, 0x5d, 0x3a, 0xbe, 0xf6, 0x76, 0x47, 0x7f, 0x37, 0x76, 0xec, 0x77, 0x63, - 0xe7, 0xa9, 0xfc, 0x6e, 0xb4, 0x09, 0x1f, 0x2c, 0x4e, 0x78, 0x5c, 0xff, 0xd6, 0x7e, 0x96, 0x9e, - 0xad, 0x2a, 0xcf, 0x47, 0x7f, 0x05, 0x00, 0x00, 0xff, 0xff, 0x5b, 0x9c, 0x1d, 0xc4, 0xb5, 0x0a, - 0x00, 0x00, + // 1156 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x56, 0x4d, 0x6f, 0xdb, 0x46, + 0x13, 0x06, 0x25, 0xcb, 0x92, 0x46, 0xf2, 0xd7, 0x5a, 0x51, 0x68, 0xd9, 0x46, 0x1c, 0xbe, 0x78, + 0xd3, 0x20, 0x6d, 0x25, 0x54, 0xf6, 0xa1, 0x35, 0xd2, 0x8f, 0xa4, 0x4e, 0xd2, 0x02, 0x49, 0x10, + 0x30, 0xa9, 0x0f, 0xbd, 0x08, 0x6b, 0x72, 0x64, 0x13, 0xa6, 0x48, 0x96, 0xbb, 0xb2, 0x2c, 0x04, + 0xbd, 0xb4, 0xc7, 0x9e, 0xda, 0xfe, 0x8f, 0xfe, 0x84, 0x5e, 0x0b, 0xf4, 0x9a, 0x7b, 0x81, 0x02, + 0xbd, 0xf6, 0x3f, 0x14, 0xbb, 0xdc, 0xa5, 0x48, 0x89, 0x0e, 0x8c, 0xf6, 0x24, 0xce, 0xcc, 0x33, + 0xbb, 0x33, 0xcf, 0x33, 0x3b, 0x10, 0x34, 0x1c, 0x9f, 0x7a, 0x71, 0x37, 0x8a, 0x43, 0x1e, 0x92, + 0xaa, 0x34, 0xa2, 0x93, 0xce, 0xce, 0x69, 0x18, 0x9e, 0xfa, 0xd8, 0xa3, 0x91, 0xd7, 0xa3, 0x41, + 0x10, 0x72, 0xca, 0xbd, 0x30, 0x60, 0x09, 0xac, 0xb3, 0xad, 0xa2, 0xd2, 0x3a, 0x19, 0x0f, 0x7b, + 0x38, 0x8a, 0xf8, 0x54, 0x05, 0x6f, 0xcd, 0x07, 0xb9, 0x37, 0x42, 0xc6, 0xe9, 0x28, 0x4a, 0x00, + 0xd6, 0x4f, 0x25, 0x58, 0x39, 0x1e, 0xfb, 0x01, 0xc6, 0xf4, 0xc4, 0xf3, 0x3d, 0x3e, 0x25, 0x04, + 0x96, 0x02, 0x3a, 0x42, 0xd3, 0xd8, 0x33, 0xee, 0xd6, 0x6d, 0xf9, 0x4d, 0xfe, 0x0f, 0xab, 0xe2, + 0x97, 0x45, 0xd4, 0xc1, 0x81, 0x8c, 0x96, 0x64, 0x74, 0x25, 0xf5, 0x3e, 0x17, 0xb0, 0x3d, 0x68, + 0xb8, 0xc8, 0x9c, 0xd8, 0x8b, 0x44, 0x81, 0x66, 0x59, 0x62, 0xb2, 0x2e, 0x71, 0xb8, 0xef, 0x05, + 0xe7, 0xe6, 0x52, 0x72, 0xb8, 0xf8, 0x26, 0x1d, 0xa8, 0x31, 0xbc, 0xc0, 0xd8, 0xe3, 0x53, 0xb3, + 0x22, 0xfd, 0xa9, 0x2d, 0x62, 0x23, 0xe4, 0xd4, 0xa5, 0x9c, 0x9a, 0xcb, 0x49, 0x4c, 0xdb, 0x64, + 0x0b, 0x6a, 0x43, 0xef, 0x12, 0xdd, 0xc1, 0xc9, 0xd4, 0xac, 0xca, 0x58, 0x55, 0xda, 0x0f, 0xa7, + 0xe4, 0x63, 0xd8, 0xa0, 0xc3, 0x21, 0x3a, 0x1c, 0xdd, 0xc1, 0x05, 0xc6, 0x4c, 0xd0, 0x65, 0xd6, + 0xf6, 0xca, 0x77, 0x1b, 0xfd, 0xf5, 0xae, 0xa2, 0xb5, 0xfb, 0x18, 0x29, 0x1f, 0xc7, 0x68, 0xaf, + 0x6b, 0xe8, 0xb1, 0x42, 0x5a, 0x3f, 0x18, 0xd0, 0xf8, 0x5c, 0xa0, 0x5e, 0x72, 0xca, 0xc7, 0x8c, + 0x98, 0x50, 0xf5, 0x3d, 0xc6, 0x31, 0x66, 0xa6, 0xb1, 0x57, 0x16, 0x17, 0x29, 0x93, 0xec, 0x40, + 0xdd, 0x45, 0x8e, 0x0e, 0x0f, 0x63, 0x66, 0x96, 0x64, 0x6c, 0xe6, 0x20, 0x47, 0xb0, 0xee, 0x53, + 0xc6, 0x07, 0xe3, 0xc8, 0xa5, 0x1c, 0x07, 0x82, 0x7b, 0x49, 0x4a, 0xa3, 0xdf, 0xe9, 0x26, 0xc2, + 0x74, 0xb5, 0x30, 0xdd, 0x57, 0x5a, 0x18, 0x7b, 0x55, 0xe4, 0x7c, 0x25, 0x53, 0x84, 0xd3, 0xfa, + 0xcd, 0x80, 0xaa, 0xaa, 0xf5, 0xbf, 0x88, 0x63, 0x42, 0x55, 0x51, 0xa1, 0x84, 0xd1, 0xa6, 0x38, + 0x40, 0x7d, 0x0e, 0x86, 0x61, 0x3c, 0xa2, 0x5c, 0xc9, 0xb3, 0xa2, 0xbc, 0x8f, 0xa5, 0x93, 0x7c, + 0x06, 0x6b, 0x17, 0x99, 0x49, 0xf1, 0x90, 0x99, 0x15, 0x49, 0x69, 0x3b, 0xa5, 0x34, 0x37, 0x49, + 0xf6, 0x3c, 0xdc, 0xfa, 0xdd, 0x80, 0xda, 0x83, 0xc0, 0x41, 0xc6, 0xe3, 0xe2, 0x39, 0x7b, 0x0f, + 0x6a, 0xc3, 0xa4, 0xd3, 0x84, 0xcd, 0x22, 0xb9, 0x52, 0x04, 0xb9, 0x03, 0xcb, 0x3e, 0x9d, 0x0a, + 0x55, 0xca, 0x12, 0xbb, 0x9a, 0x62, 0x9f, 0x0a, 0xb7, 0xad, 0xa2, 0xe4, 0x1d, 0x58, 0x63, 0x0e, + 0x0d, 0x02, 0x74, 0x07, 0x5a, 0xc6, 0x25, 0x29, 0xd5, 0xaa, 0x72, 0x3f, 0x55, 0x6a, 0xbe, 0x0b, + 0x1b, 0x1a, 0x38, 0x53, 0xb5, 0x22, 0xa1, 0xeb, 0x2a, 0x70, 0xa4, 0xfd, 0xd6, 0x36, 0x54, 0xe4, + 0x35, 0xa2, 0x91, 0x33, 0xca, 0xce, 0x74, 0x23, 0xe2, 0xdb, 0xfa, 0xc3, 0x80, 0xe6, 0xf3, 0x90, + 0x7b, 0x43, 0xcf, 0xa1, 0x7a, 0xf0, 0x17, 0xba, 0x35, 0xa1, 0xea, 0xc4, 0x48, 0x39, 0xba, 0x4a, + 0x31, 0x6d, 0x8a, 0xb1, 0x0f, 0x64, 0x36, 0xba, 0x4a, 0xac, 0xd4, 0x16, 0x59, 0x2e, 0xfa, 0x28, + 0xb2, 0x12, 0x99, 0xb4, 0x49, 0x0e, 0xa0, 0x1c, 0xfa, 0xae, 0x7c, 0x43, 0x8d, 0xbe, 0x95, 0x92, + 0xf1, 0x82, 0x9e, 0xa2, 0xab, 0x95, 0xf1, 0x51, 0x09, 0xe0, 0x21, 0xb3, 0x05, 0x5c, 0x64, 0x05, + 0x38, 0x91, 0xaf, 0xeb, 0x9a, 0x59, 0x01, 0x4e, 0xac, 0x4f, 0x61, 0xf3, 0xcb, 0xc0, 0xc5, 0x4b, + 0x74, 0xb5, 0xa0, 0x72, 0xc8, 0x5a, 0x50, 0xf1, 0x84, 0x5b, 0xf6, 0x59, 0xb1, 0x13, 0x23, 0x6d, + 0xbe, 0x34, 0x6b, 0xde, 0xfa, 0xdb, 0x80, 0xad, 0x2b, 0xef, 0x20, 0xb7, 0xa1, 0xe9, 0x8c, 0xe3, + 0x18, 0x03, 0x3e, 0x88, 0xe8, 0xa9, 0xa6, 0xad, 0xa1, 0x7c, 0x22, 0x8f, 0x6c, 0x43, 0x3d, 0xc0, + 0x4b, 0x15, 0x2f, 0x29, 0x92, 0xf0, 0x32, 0x09, 0xb6, 0xa0, 0xe2, 0x7b, 0x23, 0x8f, 0x4b, 0xf6, + 0x2a, 0x76, 0x62, 0x90, 0xfb, 0xb0, 0x92, 0x1d, 0xc9, 0xa9, 0x24, 0xf0, 0xea, 0xf9, 0xcd, 0x83, + 0xc9, 0x7d, 0x00, 0x9a, 0x56, 0xa8, 0x46, 0x7f, 0x27, 0x4d, 0x2d, 0x60, 0xc3, 0xce, 0xe0, 0xad, + 0x37, 0x25, 0xd8, 0x7c, 0x11, 0x32, 0xae, 0x01, 0x36, 0x7e, 0x33, 0x46, 0xc6, 0xc9, 0xff, 0x60, + 0x45, 0xa1, 0xa6, 0x83, 0xcc, 0x84, 0x34, 0x69, 0x96, 0xd6, 0x36, 0x2c, 0xab, 0x97, 0x99, 0x34, + 0xaa, 0x2c, 0xf2, 0xc9, 0xdc, 0x0b, 0xb8, 0x33, 0x93, 0x6f, 0xf1, 0x2a, 0xe9, 0xcb, 0xbd, 0x8c, + 0xce, 0xaf, 0x06, 0xd4, 0x53, 0x6f, 0xd1, 0x20, 0x0b, 0x5f, 0x44, 0xf9, 0x99, 0x96, 0x4e, 0x7c, + 0x93, 0x67, 0x50, 0x3d, 0x43, 0xea, 0xce, 0xae, 0xdd, 0xbf, 0xde, 0xb5, 0xdd, 0x2f, 0x92, 0xac, + 0x47, 0x81, 0x88, 0xea, 0x33, 0x3a, 0x87, 0xd0, 0xcc, 0x06, 0xc8, 0x3a, 0x94, 0xcf, 0x71, 0xaa, + 0xaa, 0x10, 0x9f, 0x42, 0xcd, 0x0b, 0xea, 0x8f, 0xb5, 0xcc, 0x89, 0x71, 0x58, 0xfa, 0xd0, 0xb0, + 0x8e, 0xa0, 0x95, 0xbf, 0x92, 0x45, 0x61, 0xc0, 0xc4, 0x22, 0x59, 0x66, 0x72, 0x77, 0xcb, 0x63, + 0x1a, 0xfd, 0x56, 0x5a, 0x61, 0x66, 0xaf, 0xdb, 0x0a, 0x63, 0xfd, 0x68, 0x00, 0x79, 0x82, 0xff, + 0x4e, 0x9a, 0x0f, 0xa0, 0x35, 0xf1, 0xf8, 0xd9, 0x60, 0x7e, 0x35, 0x8a, 0x52, 0x6b, 0xf6, 0xa6, + 0x88, 0x1d, 0xe7, 0x43, 0xe2, 0x5c, 0x99, 0x92, 0xae, 0xba, 0xb2, 0xc4, 0x36, 0x85, 0x53, 0x6d, + 0x39, 0x66, 0xc5, 0xb0, 0x99, 0x2b, 0x49, 0x35, 0xf6, 0x3e, 0xd4, 0xf4, 0xf5, 0xaa, 0xb5, 0x8d, + 0xb4, 0xb5, 0x14, 0x9c, 0x42, 0x32, 0x3c, 0x94, 0xae, 0xc1, 0xc3, 0x2f, 0x06, 0xb4, 0x9f, 0x20, + 0xcf, 0x2e, 0x2e, 0xcd, 0xc5, 0x01, 0xb4, 0x43, 0xdf, 0xcd, 0x75, 0x39, 0xcd, 0x3e, 0xcd, 0x56, + 0xe8, 0xbb, 0xb9, 0xd7, 0x23, 0x9f, 0xe1, 0x01, 0xb4, 0x03, 0x9c, 0x14, 0x65, 0x25, 0x4a, 0xb6, + 0x02, 0x9c, 0x2c, 0x66, 0x15, 0x3f, 0x5e, 0xbd, 0x44, 0x96, 0x32, 0x4b, 0xe4, 0x15, 0xdc, 0x5c, + 0xa8, 0x57, 0x11, 0xf5, 0x11, 0x34, 0x83, 0x8c, 0x5f, 0x91, 0x75, 0x23, 0xed, 0x3f, 0x97, 0x94, + 0x83, 0x5a, 0xfb, 0xb0, 0xfb, 0x8c, 0xc6, 0xe7, 0x59, 0xc4, 0x03, 0x66, 0x23, 0x75, 0x35, 0x19, + 0x05, 0xcb, 0xbc, 0xff, 0xa7, 0x01, 0x6b, 0x5a, 0x80, 0x97, 0x18, 0x5f, 0x78, 0x0e, 0x12, 0x0a, + 0xcd, 0xec, 0x74, 0x92, 0x9d, 0xb7, 0xbd, 0x93, 0xce, 0xee, 0x15, 0xd1, 0xa4, 0x21, 0xab, 0xf5, + 0xdd, 0x9b, 0xbf, 0x7e, 0x2e, 0xad, 0x5a, 0xf5, 0x9e, 0x56, 0xf7, 0xd0, 0xb8, 0x47, 0xce, 0xa1, + 0x91, 0x19, 0x13, 0xb2, 0x9d, 0x9e, 0xb1, 0x38, 0xcf, 0x9d, 0x9d, 0xe2, 0xa0, 0x3a, 0xff, 0xb6, + 0x3c, 0x7f, 0x9b, 0x6c, 0xa5, 0xe7, 0xf7, 0x5e, 0xe7, 0xc6, 0xff, 0xdb, 0xfe, 0xf7, 0x25, 0xd8, + 0xcc, 0xb2, 0xa2, 0xfb, 0x64, 0xb0, 0x36, 0x27, 0x03, 0xb9, 0x95, 0xbd, 0xab, 0x60, 0xa0, 0x3a, + 0x7b, 0x57, 0x03, 0x54, 0x41, 0xbb, 0xb2, 0xa0, 0x9b, 0xe4, 0x46, 0x2f, 0xab, 0x0e, 0xeb, 0xbd, + 0x96, 0xc5, 0x90, 0x09, 0xb4, 0x8b, 0x55, 0x22, 0xb3, 0x2d, 0xf8, 0x56, 0x19, 0x3b, 0xed, 0x85, + 0x3f, 0x61, 0x8f, 0xc4, 0x5f, 0x67, 0x7d, 0xf1, 0xbd, 0xe2, 0x8b, 0x1f, 0xd6, 0xbf, 0xd6, 0xff, + 0xcc, 0x4f, 0x96, 0x65, 0xe6, 0xfe, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x6e, 0xc1, 0x1d, 0xff, + 0xb8, 0x0b, 0x00, 0x00, } diff --git a/api/v2/clairpb/clair.pb.gw.go b/api/v3/clairpb/clair.pb.gw.go similarity index 90% rename from api/v2/clairpb/clair.pb.gw.go rename to api/v3/clairpb/clair.pb.gw.go index f45c4f86..21bfdf12 100644 --- a/api/v2/clairpb/clair.pb.gw.go +++ b/api/v3/clairpb/clair.pb.gw.go @@ -112,8 +112,8 @@ func request_NotificationService_GetNotification_0(ctx context.Context, marshale } -func request_NotificationService_DeleteNotification_0(ctx context.Context, marshaler runtime.Marshaler, client NotificationServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteNotificationRequest +func request_NotificationService_MarkNotificationAsRead_0(ctx context.Context, marshaler runtime.Marshaler, client NotificationServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq MarkNotificationAsReadRequest var metadata runtime.ServerMetadata var ( @@ -134,7 +134,7 @@ func request_NotificationService_DeleteNotification_0(ctx context.Context, marsh return nil, metadata, err } - msg, err := client.DeleteNotification(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + msg, err := client.MarkNotificationAsRead(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err } @@ -301,7 +301,7 @@ func RegisterNotificationServiceHandler(ctx context.Context, mux *runtime.ServeM }) - mux.Handle("DELETE", pattern_NotificationService_DeleteNotification_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle("DELETE", pattern_NotificationService_MarkNotificationAsRead_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(ctx) defer cancel() if cn, ok := w.(http.CloseNotifier); ok { @@ -319,14 +319,14 @@ func RegisterNotificationServiceHandler(ctx context.Context, mux *runtime.ServeM runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_NotificationService_DeleteNotification_0(rctx, inboundMarshaler, client, req, pathParams) + resp, md, err := request_NotificationService_MarkNotificationAsRead_0(rctx, inboundMarshaler, client, req, pathParams) ctx = runtime.NewServerMetadataContext(ctx, md) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - forward_NotificationService_DeleteNotification_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + forward_NotificationService_MarkNotificationAsRead_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) @@ -336,11 +336,11 @@ func RegisterNotificationServiceHandler(ctx context.Context, mux *runtime.ServeM var ( pattern_NotificationService_GetNotification_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 1, 0, 4, 1, 5, 1}, []string{"notifications", "name"}, "")) - pattern_NotificationService_DeleteNotification_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 1, 0, 4, 1, 5, 1}, []string{"notifications", "name"}, "")) + pattern_NotificationService_MarkNotificationAsRead_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 1, 0, 4, 1, 5, 1}, []string{"notifications", "name"}, "")) ) var ( forward_NotificationService_GetNotification_0 = runtime.ForwardResponseMessage - forward_NotificationService_DeleteNotification_0 = runtime.ForwardResponseMessage + forward_NotificationService_MarkNotificationAsRead_0 = runtime.ForwardResponseMessage ) diff --git a/api/v2/clairpb/clair.proto b/api/v3/clairpb/clair.proto similarity index 57% rename from api/v2/clairpb/clair.proto rename to api/v3/clairpb/clair.proto index c2e8fb06..8d704230 100644 --- a/api/v2/clairpb/clair.proto +++ b/api/v3/clairpb/clair.proto @@ -18,6 +18,7 @@ option go_package = "clairpb"; package clairpb; import "google/api/annotations.proto"; import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; message Vulnerability { string name = 1; @@ -26,60 +27,72 @@ message Vulnerability { string link = 4; string severity = 5; string metadata = 6; + // fixed_by exists when vulnerability is under feature. string fixed_by = 7; - repeated Feature fixed_in_features = 8; + // affected_versions exists when vulnerability is under notification. + repeated Feature affected_versions = 8; } -message Feature { +message ClairStatus { + // listers and detectors are processors implemented in this Clair and used to + // scan ancestries + repeated string listers = 1; + repeated string detectors = 2; + google.protobuf.Timestamp last_update_time = 3; +} + +message Feature{ string name = 1; string namespace_name = 2; string version = 3; string version_format = 4; - string added_by = 5; - repeated Vulnerability vulnerabilities = 6; + repeated Vulnerability vulnerabilities = 5; } message Ancestry { string name = 1; - int32 engine_version = 2; + repeated Feature features = 2; repeated Layer layers = 3; -} -message LayersIntroducingVulnerabilty { - Vulnerability vulnerability = 1; - repeated OrderedLayerName layers = 2; -} - -message OrderedLayerName { - int32 index = 1; - string layer_name = 2; + // scanned_listers and scanned_detectors are used to scan this ancestry, it + // may be different from listers and detectors in ClairStatus since the + // ancestry could be scanned by previous version of Clair. + repeated string scanned_listers = 4; + repeated string scanned_detectors = 5; } message Layer { - string name = 1; - repeated string namespace_names = 2; + string hash = 1; } message Notification { string name = 1; string created = 2; string notified = 3; - string deleted = 4; - int32 limit = 5; - Page page = 6; + string deleted = 4; + PagedVulnerableAncestries old = 5; + PagedVulnerableAncestries new = 6; } -message Page { - string this_token = 1; - string next_token = 2; - LayersIntroducingVulnerabilty old = 3; - LayersIntroducingVulnerabilty new = 4; +message IndexedAncestryName { + // index is unique to name in all streams simultaneously streamed, increasing + // and larger than all indexes in previous page in same stream. + int32 index = 1; + string name = 2; } +message PagedVulnerableAncestries { + string current_page = 1; + // if next_page is empty, it signals the end of all pages. + string next_page = 2; + int32 limit = 3; + Vulnerability vulnerability = 4; + repeated IndexedAncestryName ancestries = 5; +} message PostAncestryRequest { message PostLayer { - string name = 1; + string hash = 1; string path = 2; map headers = 3; } @@ -89,7 +102,7 @@ message PostAncestryRequest { } message PostAncestryResponse { - int32 engine_version = 1; + ClairStatus status = 1; } message GetAncestryRequest { @@ -100,25 +113,25 @@ message GetAncestryRequest { message GetAncestryResponse { Ancestry ancestry = 1; - repeated Feature features = 2; + ClairStatus status = 2; } message GetNotificationRequest { - string page = 1; - int32 limit = 2; - string name = 3; + // if the vulnerability_page is empty, it implies the first page. + string old_vulnerability_page = 1; + string new_vulnerability_page = 2; + int32 limit = 3; + string name = 4; } message GetNotificationResponse { Notification notification = 1; } -message DeleteNotificationRequest { +message MarkNotificationAsReadRequest { string name = 1; } - - service AncestryService{ rpc PostAncestry(PostAncestryRequest) returns (PostAncestryResponse) { option (google.api.http) = { @@ -141,7 +154,7 @@ service NotificationService{ }; } - rpc DeleteNotification(DeleteNotificationRequest) returns (google.protobuf.Empty) { + rpc MarkNotificationAsRead(MarkNotificationAsReadRequest) returns (google.protobuf.Empty) { option (google.api.http) = { delete: "/notifications/{name}" }; diff --git a/api/v2/clairpb/clair.swagger.json b/api/v3/clairpb/clair.swagger.json similarity index 76% rename from api/v2/clairpb/clair.swagger.json rename to api/v3/clairpb/clair.swagger.json index 99623387..3e54a8a2 100644 --- a/api/v2/clairpb/clair.swagger.json +++ b/api/v3/clairpb/clair.swagger.json @@ -98,7 +98,14 @@ "type": "string" }, { - "name": "page", + "name": "old_vulnerability_page", + "description": "if the vulnerability_page is empty, it implies the first page.", + "in": "query", + "required": false, + "type": "string" + }, + { + "name": "new_vulnerability_page", "in": "query", "required": false, "type": "string" @@ -116,7 +123,7 @@ ] }, "delete": { - "operationId": "DeleteNotification", + "operationId": "MarkNotificationAsRead", "responses": { "200": { "description": "", @@ -143,7 +150,7 @@ "PostAncestryRequestPostLayer": { "type": "object", "properties": { - "name": { + "hash": { "type": "string" }, "path": { @@ -163,15 +170,52 @@ "name": { "type": "string" }, - "engine_version": { - "type": "integer", - "format": "int32" + "features": { + "type": "array", + "items": { + "$ref": "#/definitions/clairpbFeature" + } }, "layers": { "type": "array", "items": { "$ref": "#/definitions/clairpbLayer" } + }, + "scanned_listers": { + "type": "array", + "items": { + "type": "string" + }, + "description": "scanned_listers and scanned_detectors are used to scan this ancestry, it\nmay be different from listers and detectors in ClairStatus since the\nancestry could be scanned by previous version of Clair." + }, + "scanned_detectors": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "clairpbClairStatus": { + "type": "object", + "properties": { + "listers": { + "type": "array", + "items": { + "type": "string" + }, + "title": "listers and detectors are processors implemented in this Clair and used to\nscan ancestries" + }, + "detectors": { + "type": "array", + "items": { + "type": "string" + } + }, + "last_update_time": { + "type": "string", + "format": "date-time" } } }, @@ -190,9 +234,6 @@ "version_format": { "type": "string" }, - "added_by": { - "type": "string" - }, "vulnerabilities": { "type": "array", "items": { @@ -207,11 +248,8 @@ "ancestry": { "$ref": "#/definitions/clairpbAncestry" }, - "features": { - "type": "array", - "items": { - "$ref": "#/definitions/clairpbFeature" - } + "status": { + "$ref": "#/definitions/clairpbClairStatus" } } }, @@ -223,31 +261,24 @@ } } }, - "clairpbLayer": { + "clairpbIndexedAncestryName": { "type": "object", "properties": { + "index": { + "type": "integer", + "format": "int32", + "description": "index is unique to name in all streams simultaneously streamed, increasing\nand larger than all indexes in previous page in same stream." + }, "name": { "type": "string" - }, - "namespace_names": { - "type": "array", - "items": { - "type": "string" - } } } }, - "clairpbLayersIntroducingVulnerabilty": { + "clairpbLayer": { "type": "object", "properties": { - "vulnerability": { - "$ref": "#/definitions/clairpbVulnerability" - }, - "layers": { - "type": "array", - "items": { - "$ref": "#/definitions/clairpbOrderedLayerName" - } + "hash": { + "type": "string" } } }, @@ -266,41 +297,36 @@ "deleted": { "type": "string" }, + "old": { + "$ref": "#/definitions/clairpbPagedVulnerableAncestries" + }, + "new": { + "$ref": "#/definitions/clairpbPagedVulnerableAncestries" + } + } + }, + "clairpbPagedVulnerableAncestries": { + "type": "object", + "properties": { + "current_page": { + "type": "string" + }, + "next_page": { + "type": "string", + "description": "if next_page is empty, it signals the end of all pages." + }, "limit": { "type": "integer", "format": "int32" }, - "page": { - "$ref": "#/definitions/clairpbPage" - } - } - }, - "clairpbOrderedLayerName": { - "type": "object", - "properties": { - "index": { - "type": "integer", - "format": "int32" + "vulnerability": { + "$ref": "#/definitions/clairpbVulnerability" }, - "layer_name": { - "type": "string" - } - } - }, - "clairpbPage": { - "type": "object", - "properties": { - "this_token": { - "type": "string" - }, - "next_token": { - "type": "string" - }, - "old": { - "$ref": "#/definitions/clairpbLayersIntroducingVulnerabilty" - }, - "new": { - "$ref": "#/definitions/clairpbLayersIntroducingVulnerabilty" + "ancestries": { + "type": "array", + "items": { + "$ref": "#/definitions/clairpbIndexedAncestryName" + } } } }, @@ -324,9 +350,8 @@ "clairpbPostAncestryResponse": { "type": "object", "properties": { - "engine_version": { - "type": "integer", - "format": "int32" + "status": { + "$ref": "#/definitions/clairpbClairStatus" } } }, @@ -352,13 +377,15 @@ "type": "string" }, "fixed_by": { - "type": "string" + "type": "string", + "description": "fixed_by exists when vulnerability is under feature." }, - "fixed_in_features": { + "affected_versions": { "type": "array", "items": { "$ref": "#/definitions/clairpbFeature" - } + }, + "description": "affected_versions exists when vulnerability is under notification." } } }, diff --git a/api/v3/clairpb/convert.go b/api/v3/clairpb/convert.go new file mode 100644 index 00000000..a3584587 --- /dev/null +++ b/api/v3/clairpb/convert.go @@ -0,0 +1,155 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clairpb + +import ( + "encoding/json" + "fmt" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/versionfmt" +) + +// PagedVulnerableAncestriesFromDatabaseModel converts database +// PagedVulnerableAncestries to api PagedVulnerableAncestries and assigns +// indexes to ancestries. +func PagedVulnerableAncestriesFromDatabaseModel(dbVuln *database.PagedVulnerableAncestries) (*PagedVulnerableAncestries, error) { + if dbVuln == nil { + return nil, nil + } + + vuln, err := VulnerabilityFromDatabaseModel(dbVuln.Vulnerability) + if err != nil { + return nil, err + } + + next := "" + if !dbVuln.End { + next = string(dbVuln.Next) + } + + vulnAncestry := PagedVulnerableAncestries{ + Vulnerability: vuln, + CurrentPage: string(dbVuln.Current), + NextPage: next, + Limit: int32(dbVuln.Limit), + } + + for index, ancestryName := range dbVuln.Affected { + indexedAncestry := IndexedAncestryName{ + Name: ancestryName, + Index: int32(index), + } + vulnAncestry.Ancestries = append(vulnAncestry.Ancestries, &indexedAncestry) + } + + return &vulnAncestry, nil +} + +// NotificationFromDatabaseModel converts database notification, old and new +// vulnerabilities' paged vulnerable ancestries to be api notification. +func NotificationFromDatabaseModel(dbNotification database.VulnerabilityNotificationWithVulnerable) (*Notification, error) { + var ( + noti Notification + err error + ) + + noti.Name = dbNotification.Name + if !dbNotification.Created.IsZero() { + noti.Created = fmt.Sprintf("%d", dbNotification.Created.Unix()) + } + + if !dbNotification.Notified.IsZero() { + noti.Notified = fmt.Sprintf("%d", dbNotification.Notified.Unix()) + } + + if !dbNotification.Deleted.IsZero() { + noti.Deleted = fmt.Sprintf("%d", dbNotification.Deleted.Unix()) + } + + noti.Old, err = PagedVulnerableAncestriesFromDatabaseModel(dbNotification.Old) + if err != nil { + return nil, err + } + + noti.New, err = PagedVulnerableAncestriesFromDatabaseModel(dbNotification.New) + if err != nil { + return nil, err + } + + return ¬i, nil +} + +func VulnerabilityFromDatabaseModel(dbVuln database.Vulnerability) (*Vulnerability, error) { + metaString := "" + if dbVuln.Metadata != nil { + metadataByte, err := json.Marshal(dbVuln.Metadata) + if err != nil { + return nil, err + } + metaString = string(metadataByte) + } + + return &Vulnerability{ + Name: dbVuln.Name, + NamespaceName: dbVuln.Namespace.Name, + Description: dbVuln.Description, + Link: dbVuln.Link, + Severity: string(dbVuln.Severity), + Metadata: metaString, + }, nil +} + +func VulnerabilityWithFixedInFromDatabaseModel(dbVuln database.VulnerabilityWithFixedIn) (*Vulnerability, error) { + vuln, err := VulnerabilityFromDatabaseModel(dbVuln.Vulnerability) + if err != nil { + return nil, err + } + + vuln.FixedBy = dbVuln.FixedInVersion + return vuln, nil +} + +// AncestryFromDatabaseModel converts database ancestry to api ancestry. +func AncestryFromDatabaseModel(dbAncestry database.Ancestry) *Ancestry { + ancestry := &Ancestry{ + Name: dbAncestry.Name, + } + for _, layer := range dbAncestry.Layers { + ancestry.Layers = append(ancestry.Layers, LayerFromDatabaseModel(layer)) + } + return ancestry +} + +// LayerFromDatabaseModel converts database layer to api layer. +func LayerFromDatabaseModel(dbLayer database.Layer) *Layer { + layer := Layer{Hash: dbLayer.Hash} + return &layer +} + +// NamespacedFeatureFromDatabaseModel converts database namespacedFeature to api Feature. +func NamespacedFeatureFromDatabaseModel(feature database.NamespacedFeature) *Feature { + version := feature.Feature.Version + if version == versionfmt.MaxVersion { + version = "None" + } + + return &Feature{ + Name: feature.Feature.Name, + NamespaceName: feature.Namespace.Name, + VersionFormat: feature.Namespace.VersionFormat, + Version: version, + } +} diff --git a/api/v3/rpc.go b/api/v3/rpc.go new file mode 100644 index 00000000..109bf17a --- /dev/null +++ b/api/v3/rpc.go @@ -0,0 +1,253 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v3 + +import ( + "fmt" + + "github.com/golang/protobuf/ptypes" + google_protobuf1 "github.com/golang/protobuf/ptypes/empty" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/coreos/clair" + pb "github.com/coreos/clair/api/v3/clairpb" + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/commonerr" +) + +// NotificationServer implements NotificationService interface for serving RPC. +type NotificationServer struct { + Store database.Datastore +} + +// AncestryServer implements AncestryService interface for serving RPC. +type AncestryServer struct { + Store database.Datastore +} + +// PostAncestry implements posting an ancestry via the Clair gRPC service. +func (s *AncestryServer) PostAncestry(ctx context.Context, req *pb.PostAncestryRequest) (*pb.PostAncestryResponse, error) { + ancestryName := req.GetAncestryName() + if ancestryName == "" { + return nil, status.Error(codes.InvalidArgument, "ancestry name should not be empty") + } + + layers := req.GetLayers() + if len(layers) == 0 { + return nil, status.Error(codes.InvalidArgument, "ancestry should have at least one layer") + } + + ancestryFormat := req.GetFormat() + if ancestryFormat == "" { + return nil, status.Error(codes.InvalidArgument, "ancestry format should not be empty") + } + + ancestryLayers := []clair.LayerRequest{} + for _, layer := range layers { + if layer == nil { + err := status.Error(codes.InvalidArgument, "ancestry layer is invalid") + return nil, err + } + + if layer.GetHash() == "" { + return nil, status.Error(codes.InvalidArgument, "ancestry layer hash should not be empty") + } + + if layer.GetPath() == "" { + return nil, status.Error(codes.InvalidArgument, "ancestry layer path should not be empty") + } + + ancestryLayers = append(ancestryLayers, clair.LayerRequest{ + Hash: layer.Hash, + Headers: layer.Headers, + Path: layer.Path, + }) + } + + err := clair.ProcessAncestry(s.Store, ancestryFormat, ancestryName, ancestryLayers) + if err != nil { + return nil, status.Error(codes.Internal, "ancestry is failed to be processed: "+err.Error()) + } + + clairStatus, err := s.getClairStatus() + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + return &pb.PostAncestryResponse{Status: clairStatus}, nil +} + +func (s *AncestryServer) getClairStatus() (*pb.ClairStatus, error) { + status := &pb.ClairStatus{ + Listers: clair.Processors.Listers, + Detectors: clair.Processors.Detectors, + } + + t, firstUpdate, err := clair.GetLastUpdateTime(s.Store) + if err != nil { + return nil, err + } + if firstUpdate { + return status, nil + } + + status.LastUpdateTime, err = ptypes.TimestampProto(t) + if err != nil { + return nil, err + } + return status, nil +} + +// GetAncestry implements retrieving an ancestry via the Clair gRPC service. +func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryRequest) (*pb.GetAncestryResponse, error) { + if req.GetAncestryName() == "" { + return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty") + } + + tx, err := s.Store.Begin() + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + defer tx.Rollback() + + ancestry, _, ok, err := tx.FindAncestry(req.GetAncestryName()) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } else if !ok { + return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName())) + } + + pbAncestry := pb.AncestryFromDatabaseModel(ancestry) + if req.GetWithFeatures() || req.GetWithVulnerabilities() { + ancestryWFeature, ok, err := tx.FindAncestryFeatures(ancestry.Name) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + if !ok { + return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName())) + } + pbAncestry.ScannedDetectors = ancestryWFeature.ProcessedBy.Detectors + pbAncestry.ScannedListers = ancestryWFeature.ProcessedBy.Listers + + if req.GetWithVulnerabilities() { + featureVulnerabilities, err := tx.FindAffectedNamespacedFeatures(ancestryWFeature.Features) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + for _, fv := range featureVulnerabilities { + // Ensure that every feature can be found. + if !fv.Valid { + return nil, status.Error(codes.Internal, "ancestry feature is not found") + } + + pbFeature := pb.NamespacedFeatureFromDatabaseModel(fv.NamespacedFeature) + for _, v := range fv.AffectedBy { + pbVuln, err := pb.VulnerabilityWithFixedInFromDatabaseModel(v) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + pbFeature.Vulnerabilities = append(pbFeature.Vulnerabilities, pbVuln) + } + + pbAncestry.Features = append(pbAncestry.Features, pbFeature) + } + } else { + for _, f := range ancestryWFeature.Features { + pbAncestry.Features = append(pbAncestry.Features, pb.NamespacedFeatureFromDatabaseModel(f)) + } + } + } + + clairStatus, err := s.getClairStatus() + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + return &pb.GetAncestryResponse{ + Status: clairStatus, + Ancestry: pbAncestry, + }, nil +} + +// GetNotification implements retrieving a notification via the Clair gRPC +// service. +func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNotificationRequest) (*pb.GetNotificationResponse, error) { + if req.GetName() == "" { + return nil, status.Error(codes.InvalidArgument, "notification name should not be empty") + } + + if req.GetLimit() <= 0 { + return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1") + } + + tx, err := s.Store.Begin() + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + defer tx.Rollback() + + dbNotification, ok, err := tx.FindVulnerabilityNotification( + req.GetName(), + int(req.GetLimit()), + database.PageNumber(req.GetOldVulnerabilityPage()), + database.PageNumber(req.GetNewVulnerabilityPage()), + ) + + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + if !ok { + return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName())) + } + + notification, err := pb.NotificationFromDatabaseModel(dbNotification) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + return &pb.GetNotificationResponse{Notification: notification}, nil +} + +// MarkNotificationAsRead implements deleting a notification via the Clair gRPC +// service. +func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb.MarkNotificationAsReadRequest) (*google_protobuf1.Empty, error) { + if req.GetName() == "" { + return nil, status.Error(codes.InvalidArgument, "notification name should not be empty") + } + + tx, err := s.Store.Begin() + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + defer tx.Rollback() + err = tx.DeleteNotification(req.GetName()) + if err == commonerr.ErrNotFound { + return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found") + } else if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + if err := tx.Commit(); err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + + return &google_protobuf1.Empty{}, nil +} diff --git a/api/v2/server.go b/api/v3/server.go similarity index 88% rename from api/v2/server.go rename to api/v3/server.go index 8b153680..e9267eb0 100644 --- a/api/v2/server.go +++ b/api/v3/server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package v2 +package v3 import ( "context" @@ -32,7 +32,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" - pb "github.com/coreos/clair/api/v2/clairpb" + pb "github.com/coreos/clair/api/v3/clairpb" "github.com/coreos/clair/database" ) @@ -47,8 +47,8 @@ func handleShutdown(err error) { var ( promResponseDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "clair_v2_api_response_duration_milliseconds", - Help: "The duration of time it takes to receieve and write a response to an V2 API request", + Name: "clair_v3_api_response_duration_milliseconds", + Help: "The duration of time it takes to receive and write a response to an V2 API request", Buckets: prometheus.ExponentialBuckets(9.375, 2, 10), }, []string{"route", "code"}) ) @@ -57,7 +57,7 @@ func init() { prometheus.MustRegister(promResponseDurationMilliseconds) } -func newGrpcServer(paginationKey string, store database.Datastore, tlsConfig *tls.Config) *grpc.Server { +func newGrpcServer(store database.Datastore, tlsConfig *tls.Config) *grpc.Server { grpcOpts := []grpc.ServerOption{ grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor), grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor), @@ -69,7 +69,7 @@ func newGrpcServer(paginationKey string, store database.Datastore, tlsConfig *tl grpcServer := grpc.NewServer(grpcOpts...) pb.RegisterAncestryServiceServer(grpcServer, &AncestryServer{Store: store}) - pb.RegisterNotificationServiceServer(grpcServer, &NotificationServer{PaginationKey: paginationKey, Store: store}) + pb.RegisterNotificationServiceServer(grpcServer, &NotificationServer{Store: store}) return grpcServer } @@ -98,11 +98,11 @@ func logHandler(handler http.Handler) http.Handler { } log.WithFields(log.Fields{ - "remote addr": r.RemoteAddr, - "method": r.Method, - "request uri": r.RequestURI, - "status": statusStr, - "elapsed time": time.Since(start), + "remote addr": r.RemoteAddr, + "method": r.Method, + "request uri": r.RequestURI, + "status": statusStr, + "elapsed time (ms)": float64(time.Since(start).Nanoseconds()) * 1e-6, }).Info("Handled HTTP request") }) } @@ -148,7 +148,7 @@ func servePrometheus(mux *http.ServeMux) { } // Run initializes grpc and grpc gateway api services on the same port -func Run(GrpcPort int, tlsConfig *tls.Config, PaginationKey, CertFile, KeyFile string, store database.Datastore) { +func Run(GrpcPort int, tlsConfig *tls.Config, CertFile, KeyFile string, store database.Datastore) { l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", GrpcPort)) if err != nil { log.WithError(err).Fatalf("could not bind to port %d", GrpcPort) @@ -175,7 +175,7 @@ func Run(GrpcPort int, tlsConfig *tls.Config, PaginationKey, CertFile, KeyFile s apiListener = tls.NewListener(tcpMux.Match(cmux.Any()), tlsConfig) go func() { handleShutdown(tcpMux.Serve()) }() - grpcServer := newGrpcServer(PaginationKey, store, tlsConfig) + grpcServer := newGrpcServer(store, tlsConfig) gwmux := newGrpcGatewayServer(ctx, apiListener.Addr().String(), tlsConfig) httpMux.Handle("/", gwmux) @@ -188,7 +188,7 @@ func Run(GrpcPort int, tlsConfig *tls.Config, PaginationKey, CertFile, KeyFile s apiListener = tcpMux.Match(cmux.Any()) go func() { handleShutdown(tcpMux.Serve()) }() - grpcServer := newGrpcServer(PaginationKey, store, nil) + grpcServer := newGrpcServer(store, nil) go func() { handleShutdown(grpcServer.Serve(grpcL)) }() gwmux := newGrpcGatewayServer(ctx, apiListener.Addr().String(), nil) diff --git a/cmd/clair/config.go b/cmd/clair/config.go index cccbebfb..e1628f05 100644 --- a/cmd/clair/config.go +++ b/cmd/clair/config.go @@ -20,13 +20,17 @@ import ( "os" "time" + "github.com/fernet/fernet-go" + log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" "github.com/coreos/clair" "github.com/coreos/clair/api" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/notification" - "github.com/fernet/fernet-go" + "github.com/coreos/clair/ext/vulnsrc" ) // ErrDatasourceNotLoaded is returned when the datasource variable in the @@ -43,6 +47,7 @@ type File struct { type Config struct { Database database.RegistrableComponentConfig Updater *clair.UpdaterConfig + Worker *clair.WorkerConfig Notifier *notification.Config API *api.Config } @@ -54,12 +59,16 @@ func DefaultConfig() Config { Type: "pgsql", }, Updater: &clair.UpdaterConfig{ - Interval: 1 * time.Hour, + EnabledUpdaters: vulnsrc.ListUpdaters(), + Interval: 1 * time.Hour, + }, + Worker: &clair.WorkerConfig{ + EnabledDetectors: featurens.ListDetectors(), + EnabledListers: featurefmt.ListListers(), }, API: &api.Config{ - Port: 6060, HealthPort: 6061, - GrpcPort: 6070, + GrpcPort: 6060, Timeout: 900 * time.Second, }, Notifier: ¬ification.Config{ @@ -97,14 +106,15 @@ func LoadConfig(path string) (config *Config, err error) { config = &cfgFile.Clair // Generate a pagination key if none is provided. - if config.API.PaginationKey == "" { + if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" { + log.Warn("pagination key is empty, generating...") var key fernet.Key if err = key.Generate(); err != nil { return } - config.API.PaginationKey = key.Encode() + config.Database.Options["paginationkey"] = key.Encode() } else { - _, err = fernet.DecodeKey(config.API.PaginationKey) + _, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string)) if err != nil { err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64") return diff --git a/cmd/clair/main.go b/cmd/clair/main.go index 0408a732..fbf5d256 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -30,9 +30,13 @@ import ( "github.com/coreos/clair" "github.com/coreos/clair/api" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/imagefmt" + "github.com/coreos/clair/ext/vulnsrc" "github.com/coreos/clair/pkg/formatter" "github.com/coreos/clair/pkg/stopper" + "github.com/coreos/clair/pkg/strutil" // Register database driver. _ "github.com/coreos/clair/database/pgsql" @@ -85,6 +89,43 @@ func stopCPUProfiling(f *os.File) { log.Info("stopped CPU profiling") } +func configClairVersion(config *Config) { + listers := featurefmt.ListListers() + detectors := featurens.ListDetectors() + updaters := vulnsrc.ListUpdaters() + + log.WithFields(log.Fields{ + "Listers": strings.Join(listers, ","), + "Detectors": strings.Join(detectors, ","), + "Updaters": strings.Join(updaters, ","), + }).Info("Clair registered components") + + unregDetectors := strutil.CompareStringLists(config.Worker.EnabledDetectors, detectors) + unregListers := strutil.CompareStringLists(config.Worker.EnabledListers, listers) + unregUpdaters := strutil.CompareStringLists(config.Updater.EnabledUpdaters, updaters) + if len(unregDetectors) != 0 || len(unregListers) != 0 || len(unregUpdaters) != 0 { + log.WithFields(log.Fields{ + "Unknown Detectors": strings.Join(unregDetectors, ","), + "Unknown Listers": strings.Join(unregListers, ","), + "Unknown Updaters": strings.Join(unregUpdaters, ","), + "Available Listers": strings.Join(featurefmt.ListListers(), ","), + "Available Detectors": strings.Join(featurens.ListDetectors(), ","), + "Available Updaters": strings.Join(vulnsrc.ListUpdaters(), ","), + }).Fatal("Unknown or unregistered components are configured") + } + + // verify the user specified detectors/listers/updaters are implemented. If + // some are not registered, it logs warning and won't use the unregistered + // extensions. + + clair.Processors = database.Processors{ + Detectors: strutil.CompareStringListsInBoth(config.Worker.EnabledDetectors, detectors), + Listers: strutil.CompareStringListsInBoth(config.Worker.EnabledListers, listers), + } + + clair.EnabledUpdaters = strutil.CompareStringListsInBoth(config.Updater.EnabledUpdaters, updaters) +} + // Boot starts Clair instance with the provided config. func Boot(config *Config) { rand.Seed(time.Now().UnixNano()) @@ -102,9 +143,8 @@ func Boot(config *Config) { go clair.RunNotifier(config.Notifier, db, st) // Start API - st.Begin() - go api.Run(config.API, db, st) go api.RunV2(config.API, db) + st.Begin() go api.RunHealth(config.API, db, st) @@ -135,19 +175,17 @@ func main() { } } - // Load configuration - config, err := LoadConfig(*flagConfigPath) - if err != nil { - log.WithError(err).Fatal("failed to load configuration") - } - // Initialize logging system - logLevel, err := log.ParseLevel(strings.ToUpper(*flagLogLevel)) log.SetLevel(logLevel) log.SetOutput(os.Stdout) log.SetFormatter(&formatter.JSONExtendedFormatter{ShowLn: true}) + config, err := LoadConfig(*flagConfigPath) + if err != nil { + log.WithError(err).Fatal("failed to load configuration") + } + // Enable CPU Profiling if specified if *flagCPUProfilePath != "" { defer stopCPUProfiling(startCPUProfiling(*flagCPUProfilePath)) @@ -159,5 +197,8 @@ func main() { imagefmt.SetInsecureTLS(*flagInsecureTLS) } + // configure updater and worker + configClairVersion(config) + Boot(config) } diff --git a/config.example.yaml b/config.example.yaml index ab47886c..63714bab 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -25,11 +25,15 @@ clair: # Number of elements kept in the cache # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. cachesize: 16384 + # 32-bit URL-safe base64 key used to encrypt pagination tokens + # If one is not provided, it will be generated. + # Multiple clair instances in the same cluster need the same value. + paginationkey: api: - # API server port - port: 6060 - grpcPort: 6070 + # v3 grpc/RESTful API server port + grpcport : 6060 + # Health server port # This is an unencrypted endpoint useful for load balancers to check to healthiness of the clair server. healthport: 6061 @@ -37,11 +41,6 @@ clair: # Deadline before an API request will respond with a 503 timeout: 900s - # 32-bit URL-safe base64 key used to encrypt pagination tokens - # If one is not provided, it will be generated. - # Multiple clair instances in the same cluster need the same value. - paginationkey: - # Optional PKI configuration # If you want to easily generate client certificates and CAs, try the following projects: # https://github.com/coreos/etcd-ca @@ -51,10 +50,29 @@ clair: keyfile: certfile: + worker: + namespace_detectors: + - os-release + - lsb-release + - apt-sources + - alpine-release + - redhat-release + + feature_listers: + - apk + - dpkg + - rpm + updater: # Frequency the database will be updated with vulnerabilities from the default data sources # The value 0 disables the updater entirely. interval: 2h + enabledupdaters: + - debian + - ubuntu + - rhel + - oracle + - alpine notifier: # Number of attempts before the notification is marked as failed to be sent @@ -72,9 +90,9 @@ clair: # https://github.com/cloudflare/cfssl # https://github.com/coreos/etcd-ca servername: - cafile: - keyfile: - certfile: + cafile: + keyfile: + certfile: # Optional HTTP Proxy: must be a valid URL (including the scheme). proxy: diff --git a/database/database.go b/database/database.go index d3f7fc0f..16925bb1 100644 --- a/database/database.go +++ b/database/database.go @@ -23,9 +23,9 @@ import ( ) var ( - // ErrBackendException is an error that occurs when the database backend does - // not work properly (ie. unreachable). - ErrBackendException = errors.New("database: an error occured when querying the backend") + // ErrBackendException is an error that occurs when the database backend + // does not work properly (ie. unreachable). + ErrBackendException = errors.New("database: an error occurred when querying the backend") // ErrInconsistent is an error that occurs when a database consistency check // fails (i.e. when an entity which is supposed to be unique is detected @@ -43,8 +43,8 @@ type RegistrableComponentConfig struct { var drivers = make(map[string]Driver) -// Driver is a function that opens a Datastore specified by its database driver type and specific -// configuration. +// Driver is a function that opens a Datastore specified by its database driver +// type and specific configuration. type Driver func(RegistrableComponentConfig) (Datastore, error) // Register makes a Constructor available by the provided name. @@ -70,130 +70,127 @@ func Open(cfg RegistrableComponentConfig) (Datastore, error) { return driver(cfg) } -// Datastore represents the required operations on a persistent data store for -// a Clair deployment. -type Datastore interface { - // ListNamespaces returns the entire list of known Namespaces. - ListNamespaces() ([]Namespace, error) - - // InsertLayer stores a Layer in the database. +// Session contains the required operations on a persistent data store for a +// Clair deployment. +// +// Session is started by Datastore.Begin and terminated with Commit or Rollback. +// Besides Commit and Rollback, other functions cannot be called after the +// session is terminated. +// Any function is not guaranteed to be called successfully if there's a session +// failure. +type Session interface { + // Commit commits changes to datastore. // - // A Layer is uniquely identified by its Name. - // The Name and EngineVersion fields are mandatory. - // If a Parent is specified, it is expected that it has been retrieved using - // FindLayer. - // If a Layer that already exists is inserted and the EngineVersion of the - // given Layer is higher than the stored one, the stored Layer should be - // updated. - // The function has to be idempotent, inserting a layer that already exists - // shouldn't return an error. - InsertLayer(Layer) error + // Commit call after Rollback does no-op. + Commit() error - // FindLayer retrieves a Layer from the database. + // Rollback drops changes to datastore. // - // When `withFeatures` is true, the Features field should be filled. - // When `withVulnerabilities` is true, the Features field should be filled - // and their AffectedBy fields should contain every vulnerabilities that - // affect them. - FindLayer(name string, withFeatures, withVulnerabilities bool) (Layer, error) + // Rollback call after Commit does no-op. + Rollback() error - // DeleteLayer deletes a Layer from the database and every layers that are - // based on it, recursively. - DeleteLayer(name string) error + // UpsertAncestry inserts or replaces an ancestry and its namespaced + // features and processors used to scan the ancestry. + UpsertAncestry(ancestry Ancestry, features []NamespacedFeature, processedBy Processors) error - // ListVulnerabilities returns the list of vulnerabilities of a particular - // Namespace. + // FindAncestry retrieves an ancestry with processors used to scan the + // ancestry. If the ancestry is not found, return false. // - // The Limit and page parameters are used to paginate the return list. - // The first given page should be 0. - // The function should return the next available page. If there are no more - // pages, -1 has to be returned. - ListVulnerabilities(namespaceName string, limit int, page int) ([]Vulnerability, int, error) + // The ancestry's processors are returned to short cut processing ancestry + // if it has been processed by all processors in the current Clair instance. + FindAncestry(name string) (ancestry Ancestry, processedBy Processors, found bool, err error) - // InsertVulnerabilities stores the given Vulnerabilities in the database, - // updating them if necessary. + // FindAncestryFeatures retrieves an ancestry with all detected namespaced + // features. If the ancestry is not found, return false. + FindAncestryFeatures(name string) (ancestry AncestryWithFeatures, found bool, err error) + + // PersistFeatures inserts a set of features if not in the database. + PersistFeatures(features []Feature) error + + // PersistNamespacedFeatures inserts a set of namespaced features if not in + // the database. + PersistNamespacedFeatures([]NamespacedFeature) error + + // CacheAffectedNamespacedFeatures relates the namespaced features with the + // vulnerabilities affecting these features. // - // A vulnerability is uniquely identified by its Namespace and its Name. - // The FixedIn field may only contain a partial list of Features that are - // affected by the Vulnerability, along with the version in which the - // vulnerability is fixed. It is the responsibility of the implementation to - // update the list properly. - // A version equals to versionfmt.MinVersion means that the given Feature is - // not being affected by the Vulnerability at all and thus, should be removed - // from the list. - // It is important that Features should be unique in the FixedIn list. For - // example, it doesn't make sense to have two `openssl` Feature listed as a - // Vulnerability can only be fixed in one Version. This is true because - // Vulnerabilities and Features are namespaced (i.e. specific to one - // operating system). - // Each vulnerability insertion or update has to create a Notification that - // will contain the old and the updated Vulnerability, unless - // createNotification equals to true. - InsertVulnerabilities(vulnerabilities []Vulnerability, createNotification bool) error + // NOTE(Sida): it's not necessary for every database implementation and so + // this function may have a better home. + CacheAffectedNamespacedFeatures([]NamespacedFeature) error - // FindVulnerability retrieves a Vulnerability from the database, including - // the FixedIn list. - FindVulnerability(namespaceName, name string) (Vulnerability, error) + // FindAffectedNamespacedFeatures retrieves a set of namespaced features + // with affecting vulnerabilities. + FindAffectedNamespacedFeatures(features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) - // DeleteVulnerability removes a Vulnerability from the database. + // PersistNamespaces inserts a set of namespaces if not in the database. + PersistNamespaces([]Namespace) error + + // PersistLayer inserts a layer if not in the datastore. + PersistLayer(Layer) error + + // PersistLayerContent persists a layer's content in the database. The given + // namespaces and features can be partial content of this layer. // - // It has to create a Notification that will contain the old Vulnerability. - DeleteVulnerability(namespaceName, name string) error + // The layer, namespaces and features are expected to be already existing + // in the database. + PersistLayerContent(hash string, namespaces []Namespace, features []Feature, processedBy Processors) error - // InsertVulnerabilityFixes adds new FixedIn Feature or update the Versions - // of existing ones to the specified Vulnerability in the database. + // FindLayer retrieves a layer and the processors scanned the layer. + FindLayer(hash string) (layer Layer, processedBy Processors, found bool, err error) + + // FindLayerWithContent returns a layer with all detected features and + // namespaces. + FindLayerWithContent(hash string) (layer LayerWithContent, found bool, err error) + + // InsertVulnerabilities inserts a set of UNIQUE vulnerabilities with + // affected features into database, assuming that all vulnerabilities + // provided are NOT in database and all vulnerabilities' namespaces are + // already in the database. + InsertVulnerabilities([]VulnerabilityWithAffected) error + + // FindVulnerability retrieves a set of Vulnerabilities with affected + // features. + FindVulnerabilities([]VulnerabilityID) ([]NullableVulnerability, error) + + // DeleteVulnerability removes a set of Vulnerabilities assuming that the + // requested vulnerabilities are in the database. + DeleteVulnerabilities([]VulnerabilityID) error + + // InsertVulnerabilityNotifications inserts a set of unique vulnerability + // notifications into datastore, assuming that they are not in the database. + InsertVulnerabilityNotifications([]VulnerabilityNotification) error + + // FindNewNotification retrieves a notification, which has never been + // notified or notified before a certain time. + FindNewNotification(notifiedBefore time.Time) (hook NotificationHook, found bool, err error) + + // FindVulnerabilityNotification retrieves a vulnerability notification with + // affected ancestries affected by old or new vulnerability. // - // It has has to create a Notification that will contain the old and the - // updated Vulnerability. - InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error - - // DeleteVulnerabilityFix removes a FixedIn Feature from the specified - // Vulnerability in the database. It can be used to store the fact that a - // Vulnerability no longer affects the given Feature in any Version. + // Because the number of affected ancestries maybe large, they are paginated + // and their pages are specified by the given encrypted PageNumbers, which, + // if empty, are always considered first page. // - // It has has to create a Notification that will contain the old and the - // updated Vulnerability. - DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error + // Session interface implementation should have encrypt and decrypt + // functions for PageNumber. + FindVulnerabilityNotification(name string, limit int, + oldVulnerabilityPage PageNumber, + newVulnerabilityPage PageNumber) ( + noti VulnerabilityNotificationWithVulnerable, + found bool, err error) - // GetAvailableNotification returns the Name, Created, Notified and Deleted - // fields of a Notification that should be handled. - // - // The renotify interval defines how much time after being marked as Notified - // by SetNotificationNotified, a Notification that hasn't been deleted should - // be returned again by this function. - // A Notification for which there is a valid Lock with the same Name should - // not be returned. - GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) + // MarkNotificationNotified marks a Notification as notified now, assuming + // the requested notification is in the database. + MarkNotificationNotified(name string) error - // GetNotification returns a Notification, including its OldVulnerability and - // NewVulnerability fields. - // - // On these Vulnerabilities, LayersIntroducingVulnerability should be filled - // with every Layer that introduces the Vulnerability (i.e. adds at least one - // affected FeatureVersion). - // The Limit and page parameters are used to paginate - // LayersIntroducingVulnerability. The first given page should be - // VulnerabilityNotificationFirstPage. The function will then return the next - // available page. If there is no more page, NoVulnerabilityNotificationPage - // has to be returned. - GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) - - // SetNotificationNotified marks a Notification as notified and thus, makes - // it unavailable for GetAvailableNotification, until the renotify duration - // is elapsed. - SetNotificationNotified(name string) error - - // DeleteNotification marks a Notification as deleted, and thus, makes it - // unavailable for GetAvailableNotification. + // DeleteNotification removes a Notification in the database. DeleteNotification(name string) error - // InsertKeyValue stores or updates a simple key/value pair in the database. - InsertKeyValue(key, value string) error + // UpdateKeyValue stores or updates a simple key/value pair. + UpdateKeyValue(key, value string) error - // GetKeyValue retrieves a value from the database from the given key. - // - // It returns an empty string if there is no such key. - GetKeyValue(key string) (string, error) + // FindKeyValue retrieves a value from the given key. + FindKeyValue(key string) (value string, found bool, err error) // Lock creates or renew a Lock in the database with the given name, owner // and duration. @@ -204,14 +201,20 @@ type Datastore interface { // Lock should not block, it should instead returns whether the Lock has been // successfully acquired/renewed. If it's the case, the expiration time of // that Lock is returned as well. - Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) + Lock(name string, owner string, duration time.Duration, renew bool) (success bool, expiration time.Time, err error) // Unlock releases an existing Lock. - Unlock(name, owner string) + Unlock(name, owner string) error // FindLock returns the owner of a Lock specified by the name, and its // expiration time if it exists. - FindLock(name string) (string, time.Time, error) + FindLock(name string) (owner string, expiration time.Time, found bool, err error) +} + +// Datastore represents a persistent data store +type Datastore interface { + // Begin starts a session to change. + Begin() (Session, error) // Ping returns the health status of the database. Ping() bool diff --git a/database/mock.go b/database/mock.go index 9a0963c8..966e9c88 100644 --- a/database/mock.go +++ b/database/mock.go @@ -16,161 +16,240 @@ package database import "time" +// MockSession implements Session and enables overriding each available method. +// The default behavior of each method is to simply panic. +type MockSession struct { + FctCommit func() error + FctRollback func() error + FctUpsertAncestry func(Ancestry, []NamespacedFeature, Processors) error + FctFindAncestry func(name string) (Ancestry, Processors, bool, error) + FctFindAncestryFeatures func(name string) (AncestryWithFeatures, bool, error) + FctFindAffectedNamespacedFeatures func(features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) + FctPersistNamespaces func([]Namespace) error + FctPersistFeatures func([]Feature) error + FctPersistNamespacedFeatures func([]NamespacedFeature) error + FctCacheAffectedNamespacedFeatures func([]NamespacedFeature) error + FctPersistLayer func(Layer) error + FctPersistLayerContent func(hash string, namespaces []Namespace, features []Feature, processedBy Processors) error + FctFindLayer func(name string) (Layer, Processors, bool, error) + FctFindLayerWithContent func(name string) (LayerWithContent, bool, error) + FctInsertVulnerabilities func([]VulnerabilityWithAffected) error + FctFindVulnerabilities func([]VulnerabilityID) ([]NullableVulnerability, error) + FctDeleteVulnerabilities func([]VulnerabilityID) error + FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error + FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) + FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) + FctMarkNotificationNotified func(name string) error + FctDeleteNotification func(name string) error + FctUpdateKeyValue func(key, value string) error + FctFindKeyValue func(key string) (string, bool, error) + FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) + FctUnlock func(name, owner string) error + FctFindLock func(name string) (string, time.Time, bool, error) +} + +func (ms *MockSession) Commit() error { + if ms.FctCommit != nil { + return ms.FctCommit() + } + panic("required mock function not implemented") +} + +func (ms *MockSession) Rollback() error { + if ms.FctRollback != nil { + return ms.FctRollback() + } + panic("required mock function not implemented") +} + +func (ms *MockSession) UpsertAncestry(ancestry Ancestry, features []NamespacedFeature, processedBy Processors) error { + if ms.FctUpsertAncestry != nil { + return ms.FctUpsertAncestry(ancestry, features, processedBy) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindAncestry(name string) (Ancestry, Processors, bool, error) { + if ms.FctFindAncestry != nil { + return ms.FctFindAncestry(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindAncestryFeatures(name string) (AncestryWithFeatures, bool, error) { + if ms.FctFindAncestryFeatures != nil { + return ms.FctFindAncestryFeatures(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindAffectedNamespacedFeatures(features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) { + if ms.FctFindAffectedNamespacedFeatures != nil { + return ms.FctFindAffectedNamespacedFeatures(features) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) PersistNamespaces(namespaces []Namespace) error { + if ms.FctPersistNamespaces != nil { + return ms.FctPersistNamespaces(namespaces) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) PersistFeatures(features []Feature) error { + if ms.FctPersistFeatures != nil { + return ms.FctPersistFeatures(features) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) PersistNamespacedFeatures(namespacedFeatures []NamespacedFeature) error { + if ms.FctPersistNamespacedFeatures != nil { + return ms.FctPersistNamespacedFeatures(namespacedFeatures) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) CacheAffectedNamespacedFeatures(namespacedFeatures []NamespacedFeature) error { + if ms.FctCacheAffectedNamespacedFeatures != nil { + return ms.FctCacheAffectedNamespacedFeatures(namespacedFeatures) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) PersistLayer(layer Layer) error { + if ms.FctPersistLayer != nil { + return ms.FctPersistLayer(layer) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) PersistLayerContent(hash string, namespaces []Namespace, features []Feature, processedBy Processors) error { + if ms.FctPersistLayerContent != nil { + return ms.FctPersistLayerContent(hash, namespaces, features, processedBy) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindLayer(name string) (Layer, Processors, bool, error) { + if ms.FctFindLayer != nil { + return ms.FctFindLayer(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindLayerWithContent(name string) (LayerWithContent, bool, error) { + if ms.FctFindLayerWithContent != nil { + return ms.FctFindLayerWithContent(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) InsertVulnerabilities(vulnerabilities []VulnerabilityWithAffected) error { + if ms.FctInsertVulnerabilities != nil { + return ms.FctInsertVulnerabilities(vulnerabilities) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindVulnerabilities(vulnerabilityIDs []VulnerabilityID) ([]NullableVulnerability, error) { + if ms.FctFindVulnerabilities != nil { + return ms.FctFindVulnerabilities(vulnerabilityIDs) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) DeleteVulnerabilities(VulnerabilityIDs []VulnerabilityID) error { + if ms.FctDeleteVulnerabilities != nil { + return ms.FctDeleteVulnerabilities(VulnerabilityIDs) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) InsertVulnerabilityNotifications(vulnerabilityNotifications []VulnerabilityNotification) error { + if ms.FctInsertVulnerabilityNotifications != nil { + return ms.FctInsertVulnerabilityNotifications(vulnerabilityNotifications) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindNewNotification(lastNotified time.Time) (NotificationHook, bool, error) { + if ms.FctFindNewNotification != nil { + return ms.FctFindNewNotification(lastNotified) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + VulnerabilityNotificationWithVulnerable, bool, error) { + if ms.FctFindVulnerabilityNotification != nil { + return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) MarkNotificationNotified(name string) error { + if ms.FctMarkNotificationNotified != nil { + return ms.FctMarkNotificationNotified(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) DeleteNotification(name string) error { + if ms.FctDeleteNotification != nil { + return ms.FctDeleteNotification(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) UpdateKeyValue(key, value string) error { + if ms.FctUpdateKeyValue != nil { + return ms.FctUpdateKeyValue(key, value) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindKeyValue(key string) (string, bool, error) { + if ms.FctFindKeyValue != nil { + return ms.FctFindKeyValue(key) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) { + if ms.FctLock != nil { + return ms.FctLock(name, owner, duration, renew) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) Unlock(name, owner string) error { + if ms.FctUnlock != nil { + return ms.FctUnlock(name, owner) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindLock(name string) (string, time.Time, bool, error) { + if ms.FctFindLock != nil { + return ms.FctFindLock(name) + } + panic("required mock function not implemented") +} + // MockDatastore implements Datastore and enables overriding each available method. // The default behavior of each method is to simply panic. type MockDatastore struct { - FctListNamespaces func() ([]Namespace, error) - FctInsertLayer func(Layer) error - FctFindLayer func(name string, withFeatures, withVulnerabilities bool) (Layer, error) - FctDeleteLayer func(name string) error - FctListVulnerabilities func(namespaceName string, limit int, page int) ([]Vulnerability, int, error) - FctInsertVulnerabilities func(vulnerabilities []Vulnerability, createNotification bool) error - FctFindVulnerability func(namespaceName, name string) (Vulnerability, error) - FctDeleteVulnerability func(namespaceName, name string) error - FctInsertVulnerabilityFixes func(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error - FctDeleteVulnerabilityFix func(vulnerabilityNamespace, vulnerabilityName, featureName string) error - FctGetAvailableNotification func(renotifyInterval time.Duration) (VulnerabilityNotification, error) - FctGetNotification func(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) - FctSetNotificationNotified func(name string) error - FctDeleteNotification func(name string) error - FctInsertKeyValue func(key, value string) error - FctGetKeyValue func(key string) (string, error) - FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) - FctUnlock func(name, owner string) - FctFindLock func(name string) (string, time.Time, error) - FctPing func() bool - FctClose func() + FctBegin func() (Session, error) + FctPing func() bool + FctClose func() } -func (mds *MockDatastore) ListNamespaces() ([]Namespace, error) { - if mds.FctListNamespaces != nil { - return mds.FctListNamespaces() - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) InsertLayer(layer Layer) error { - if mds.FctInsertLayer != nil { - return mds.FctInsertLayer(layer) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) FindLayer(name string, withFeatures, withVulnerabilities bool) (Layer, error) { - if mds.FctFindLayer != nil { - return mds.FctFindLayer(name, withFeatures, withVulnerabilities) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) DeleteLayer(name string) error { - if mds.FctDeleteLayer != nil { - return mds.FctDeleteLayer(name) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) ListVulnerabilities(namespaceName string, limit int, page int) ([]Vulnerability, int, error) { - if mds.FctListVulnerabilities != nil { - return mds.FctListVulnerabilities(namespaceName, limit, page) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) InsertVulnerabilities(vulnerabilities []Vulnerability, createNotification bool) error { - if mds.FctInsertVulnerabilities != nil { - return mds.FctInsertVulnerabilities(vulnerabilities, createNotification) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) FindVulnerability(namespaceName, name string) (Vulnerability, error) { - if mds.FctFindVulnerability != nil { - return mds.FctFindVulnerability(namespaceName, name) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) DeleteVulnerability(namespaceName, name string) error { - if mds.FctDeleteVulnerability != nil { - return mds.FctDeleteVulnerability(namespaceName, name) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error { - if mds.FctInsertVulnerabilityFixes != nil { - return mds.FctInsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName, fixes) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - if mds.FctDeleteVulnerabilityFix != nil { - return mds.FctDeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) { - if mds.FctGetAvailableNotification != nil { - return mds.FctGetAvailableNotification(renotifyInterval) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) { - if mds.FctGetNotification != nil { - return mds.FctGetNotification(name, limit, page) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) SetNotificationNotified(name string) error { - if mds.FctSetNotificationNotified != nil { - return mds.FctSetNotificationNotified(name) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) DeleteNotification(name string) error { - if mds.FctDeleteNotification != nil { - return mds.FctDeleteNotification(name) - } - panic("required mock function not implemented") -} -func (mds *MockDatastore) InsertKeyValue(key, value string) error { - if mds.FctInsertKeyValue != nil { - return mds.FctInsertKeyValue(key, value) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) GetKeyValue(key string) (string, error) { - if mds.FctGetKeyValue != nil { - return mds.FctGetKeyValue(key) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { - if mds.FctLock != nil { - return mds.FctLock(name, owner, duration, renew) - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) Unlock(name, owner string) { - if mds.FctUnlock != nil { - mds.FctUnlock(name, owner) - return - } - panic("required mock function not implemented") -} - -func (mds *MockDatastore) FindLock(name string) (string, time.Time, error) { - if mds.FctFindLock != nil { - return mds.FctFindLock(name) +func (mds *MockDatastore) Begin() (Session, error) { + if mds.FctBegin != nil { + return mds.FctBegin() } panic("required mock function not implemented") } diff --git a/database/models.go b/database/models.go index 608f2449..fe36fbfc 100644 --- a/database/models.go +++ b/database/models.go @@ -20,49 +20,115 @@ import ( "time" ) -// ID is only meant to be used by database implementations and should never be used for anything else. -type Model struct { - ID int +// Processors are extentions to scan layer's content. +type Processors struct { + Listers []string + Detectors []string } +// Ancestry is a manifest that keeps all layers in an image in order. +type Ancestry struct { + Name string + // Layers should be ordered and i_th layer is the parent of i+1_th layer in + // the slice. + Layers []Layer +} + +// AncestryWithFeatures is an ancestry with namespaced features detected in the +// ancestry, which is processed by `ProcessedBy`. +type AncestryWithFeatures struct { + Ancestry + + ProcessedBy Processors + Features []NamespacedFeature +} + +// Layer corresponds to a layer in an image processed by `ProcessedBy`. type Layer struct { - Model - - Name string - EngineVersion int - Parent *Layer - Namespaces []Namespace - Features []FeatureVersion + // Hash is content hash of the layer. + Hash string } -type Namespace struct { - Model +// LayerWithContent is a layer with its detected namespaces and features by +// ProcessedBy. +type LayerWithContent struct { + Layer + ProcessedBy Processors + Namespaces []Namespace + Features []Feature +} + +// Namespace is the contextual information around features. +// +// e.g. Debian:7, NodeJS. +type Namespace struct { Name string VersionFormat string } +// Feature represents a package detected in a layer but the namespace is not +// determined. +// +// e.g. Name: OpenSSL, Version: 1.0, VersionFormat: dpkg. +// dpkg implies the installer package manager but the namespace (might be +// debian:7, debian:8, ...) could not be determined. type Feature struct { - Model + Name string + Version string + VersionFormat string +} + +// NamespacedFeature is a feature with determined namespace and can be affected +// by vulnerabilities. +// +// e.g. OpenSSL 1.0 dpkg Debian:7. +type NamespacedFeature struct { + Feature - Name string Namespace Namespace } -type FeatureVersion struct { - Model +// AffectedNamespacedFeature is a namespaced feature affected by the +// vulnerabilities with fixed-in versions for this feature. +type AffectedNamespacedFeature struct { + NamespacedFeature - Feature Feature - Version string - AffectedBy []Vulnerability - - // For output purposes. Only make sense when the feature version is in the context of an image. - AddedBy Layer + AffectedBy []VulnerabilityWithFixedIn } -type Vulnerability struct { - Model +// VulnerabilityWithFixedIn is used for AffectedNamespacedFeature to retrieve +// the affecting vulnerabilities and the fixed-in versions for the feature. +type VulnerabilityWithFixedIn struct { + Vulnerability + FixedInVersion string +} + +// AffectedFeature is used to determine whether a namespaced feature is affected +// by a Vulnerability. Namespace and Feature Name is unique. Affected Feature is +// bound to vulnerability. +type AffectedFeature struct { + Namespace Namespace + FeatureName string + // FixedInVersion is known next feature version that's not affected by the + // vulnerability. Empty FixedInVersion means the unaffected version is + // unknown. + FixedInVersion string + // AffectedVersion contains the version range to determine whether or not a + // feature is affected. + AffectedVersion string +} + +// VulnerabilityID is an identifier for every vulnerability. Every vulnerability +// has unique namespace and name. +type VulnerabilityID struct { + Name string + Namespace string +} + +// Vulnerability represents CVE or similar vulnerability reports. +type Vulnerability struct { Name string Namespace Namespace @@ -71,17 +137,85 @@ type Vulnerability struct { Severity Severity Metadata MetadataMap - - FixedIn []FeatureVersion - LayersIntroducingVulnerability []Layer - - // For output purposes. Only make sense when the vulnerability - // is already about a specific Feature/FeatureVersion. - FixedBy string `json:",omitempty"` } +// VulnerabilityWithAffected is an vulnerability with all known affected +// features. +type VulnerabilityWithAffected struct { + Vulnerability + + Affected []AffectedFeature +} + +// PagedVulnerableAncestries is a vulnerability with a page of affected +// ancestries each with a special index attached for streaming purpose. The +// current page number and next page number are for navigate. +type PagedVulnerableAncestries struct { + Vulnerability + + // Affected is a map of special indexes to Ancestries, which the pair + // should be unique in a stream. Every indexes in the map should be larger + // than previous page. + Affected map[int]string + + Limit int + Current PageNumber + Next PageNumber + + // End signals the end of the pages. + End bool +} + +// NotificationHook is a message sent to another service to inform of a change +// to a Vulnerability or the Ancestries affected by a Vulnerability. It contains +// the name of a notification that should be read and marked as read via the +// API. +type NotificationHook struct { + Name string + + Created time.Time + Notified time.Time + Deleted time.Time +} + +// VulnerabilityNotification is a notification for vulnerability changes. +type VulnerabilityNotification struct { + NotificationHook + + Old *Vulnerability + New *Vulnerability +} + +// VulnerabilityNotificationWithVulnerable is a notification for vulnerability +// changes with vulnerable ancestries. +type VulnerabilityNotificationWithVulnerable struct { + NotificationHook + + Old *PagedVulnerableAncestries + New *PagedVulnerableAncestries +} + +// PageNumber is used to do pagination. +type PageNumber string + type MetadataMap map[string]interface{} +// NullableAffectedNamespacedFeature is an affectednamespacedfeature with +// whether it's found in datastore. +type NullableAffectedNamespacedFeature struct { + AffectedNamespacedFeature + + Valid bool +} + +// NullableVulnerability is a vulnerability with whether the vulnerability is +// found in datastore. +type NullableVulnerability struct { + VulnerabilityWithAffected + + Valid bool +} + func (mm *MetadataMap) Scan(value interface{}) error { if value == nil { return nil @@ -99,25 +233,3 @@ func (mm *MetadataMap) Value() (driver.Value, error) { json, err := json.Marshal(*mm) return string(json), err } - -type VulnerabilityNotification struct { - Model - - Name string - - Created time.Time - Notified time.Time - Deleted time.Time - - OldVulnerability *Vulnerability - NewVulnerability *Vulnerability -} - -type VulnerabilityNotificationPageNumber struct { - // -1 means that we reached the end already. - OldVulnerability int - NewVulnerability int -} - -var VulnerabilityNotificationFirstPage = VulnerabilityNotificationPageNumber{0, 0} -var NoVulnerabilityNotificationPage = VulnerabilityNotificationPageNumber{-1, -1} diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go new file mode 100644 index 00000000..17033144 --- /dev/null +++ b/database/pgsql/ancestry.go @@ -0,0 +1,261 @@ +package pgsql + +import ( + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/lib/pq" + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/commonerr" +) + +func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry, features []database.NamespacedFeature, processedBy database.Processors) error { + if ancestry.Name == "" { + log.Warning("Empty ancestry name is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with empty name") + } + + if len(ancestry.Layers) == 0 { + log.Warning("Empty ancestry is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers") + } + + err := tx.deleteAncestry(ancestry.Name) + if err != nil { + return err + } + + var ancestryID int64 + err = tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID) + if err != nil { + if isErrUniqueViolation(err) { + return handleError("insertAncestry", errors.New("Other Go-routine is processing this ancestry (skip).")) + } + return handleError("insertAncestry", err) + } + + err = tx.insertAncestryLayers(ancestryID, ancestry.Layers) + if err != nil { + return err + } + + err = tx.insertAncestryFeatures(ancestryID, features) + if err != nil { + return err + } + + return tx.persistProcessors(persistAncestryLister, + "persistAncestryLister", + persistAncestryDetector, + "persistAncestryDetector", + ancestryID, processedBy) +} + +func (tx *pgSession) FindAncestry(name string) (database.Ancestry, database.Processors, bool, error) { + ancestry := database.Ancestry{Name: name} + processed := database.Processors{} + + var ancestryID int64 + err := tx.QueryRow(searchAncestry, name).Scan(&ancestryID) + if err != nil { + if err == sql.ErrNoRows { + return ancestry, processed, false, nil + } + return ancestry, processed, false, handleError("searchAncestry", err) + } + + ancestry.Layers, err = tx.findAncestryLayers(ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Detectors, err = tx.findProcessors(searchAncestryDetectors, "searchAncestryDetectors", "detector", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Listers, err = tx.findProcessors(searchAncestryListers, "searchAncestryListers", "lister", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + return ancestry, processed, true, nil +} + +func (tx *pgSession) FindAncestryFeatures(name string) (database.AncestryWithFeatures, bool, error) { + var ( + awf database.AncestryWithFeatures + ok bool + err error + ) + awf.Ancestry, awf.ProcessedBy, ok, err = tx.FindAncestry(name) + if err != nil { + return awf, false, err + } + + if !ok { + return awf, false, nil + } + + rows, err := tx.Query(searchAncestryFeatures, name) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + + for rows.Next() { + nf := database.NamespacedFeature{} + err := rows.Scan(&nf.Namespace.Name, &nf.Namespace.VersionFormat, &nf.Feature.Name, &nf.Feature.Version) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + nf.Feature.VersionFormat = nf.Namespace.VersionFormat + awf.Features = append(awf.Features, nf) + } + + return awf, true, nil +} + +func (tx *pgSession) deleteAncestry(name string) error { + result, err := tx.Exec(removeAncestry, name) + if err != nil { + return handleError("removeAncestry", err) + } + + _, err = result.RowsAffected() + if err != nil { + return handleError("removeAncestry", err) + } + + return nil +} + +func (tx *pgSession) findProcessors(query, queryName, processorType string, id int64) ([]string, error) { + rows, err := tx.Query(query, id) + if err != nil { + if err == sql.ErrNoRows { + log.Warning("No " + processorType + " are used") + return nil, nil + } + return nil, handleError(queryName, err) + } + + var ( + processors []string + processor string + ) + + for rows.Next() { + err := rows.Scan(&processor) + if err != nil { + return nil, handleError(queryName, err) + } + processors = append(processors, processor) + } + + return processors, nil +} + +func (tx *pgSession) findAncestryLayers(ancestryID int64) ([]database.Layer, error) { + rows, err := tx.Query(searchAncestryLayer, ancestryID) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers := []database.Layer{} + for rows.Next() { + var layer database.Layer + err := rows.Scan(&layer.Hash) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers = append(layers, layer) + } + return layers, nil +} + +func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.Layer) error { + layerIDs := map[string]sql.NullInt64{} + for _, l := range layers { + layerIDs[l.Hash] = sql.NullInt64{} + } + + layerHashes := []string{} + for hash := range layerIDs { + layerHashes = append(layerHashes, hash) + } + + rows, err := tx.Query(searchLayerIDs, pq.Array(layerHashes)) + if err != nil { + return handleError("searchLayerIDs", err) + } + + for rows.Next() { + var ( + layerID sql.NullInt64 + layerName string + ) + err := rows.Scan(&layerID, &layerName) + if err != nil { + return handleError("searchLayerIDs", err) + } + layerIDs[layerName] = layerID + } + + notFound := []string{} + for hash, id := range layerIDs { + if !id.Valid { + notFound = append(notFound, hash) + } + } + + if len(notFound) > 0 { + return handleError("searchLayerIDs", fmt.Errorf("Layer %s is not found in database", strings.Join(notFound, ","))) + } + + //TODO(Sida): use bulk insert. + stmt, err := tx.Prepare(insertAncestryLayer) + if err != nil { + return handleError("insertAncestryLayer", err) + } + + defer stmt.Close() + for index, layer := range layers { + _, err := stmt.Exec(ancestryID, index, layerIDs[layer.Hash].Int64) + if err != nil { + return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close())) + } + } + + return nil +} + +func (tx *pgSession) insertAncestryFeatures(ancestryID int64, features []database.NamespacedFeature) error { + featureIDs, err := tx.findNamespacedFeatureIDs(features) + if err != nil { + return err + } + + //TODO(Sida): use bulk insert. + stmtFeatures, err := tx.Prepare(insertAncestryFeature) + if err != nil { + return handleError("insertAncestryFeature", err) + } + + defer stmtFeatures.Close() + + for _, id := range featureIDs { + if !id.Valid { + return errors.New("requested namespaced feature is not in database") + } + + _, err := stmtFeatures.Exec(ancestryID, id) + if err != nil { + return handleError("insertAncestryFeature", err) + } + } + + return nil +} diff --git a/database/pgsql/ancestry_test.go b/database/pgsql/ancestry_test.go new file mode 100644 index 00000000..7851163c --- /dev/null +++ b/database/pgsql/ancestry_test.go @@ -0,0 +1,207 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database" +) + +func TestUpsertAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + a1 := database.Ancestry{ + Name: "a1", + Layers: []database.Layer{ + {Hash: "layer-N"}, + }, + } + + a2 := database.Ancestry{} + + a3 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-0"}, + }, + } + + a4 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-1"}, + }, + } + + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + } + + // not in database + f2 := database.Feature{ + Name: "wechat", + Version: "0.6", + VersionFormat: "dpkg", + } + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + p := database.Processors{ + Listers: []string{"dpkg", "non-existing"}, + Detectors: []string{"os-release", "non-existing"}, + } + + nsf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // not in database + nsf2 := database.NamespacedFeature{ + Namespace: n1, + Feature: f2, + } + + // invalid case + assert.NotNil(t, tx.UpsertAncestry(a1, nil, database.Processors{})) + assert.NotNil(t, tx.UpsertAncestry(a2, nil, database.Processors{})) + // valid case + assert.Nil(t, tx.UpsertAncestry(a3, nil, database.Processors{})) + // replace invalid case + assert.NotNil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1, nsf2}, p)) + // replace valid case + assert.Nil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1}, p)) + // validate + ancestry, ok, err := tx.FindAncestryFeatures("a") + assert.Nil(t, err) + assert.True(t, ok) + assert.Equal(t, a4, ancestry.Ancestry) +} + +func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool { + sort.Strings(expected.Detectors) + sort.Strings(actual.Detectors) + sort.Strings(expected.Listers) + sort.Strings(actual.Listers) + return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers) +} + +func TestFindAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "FindAncestry", true) + defer closeTest(t, store, tx) + + // not found + _, _, ok, err := tx.FindAncestry("ancestry-non") + assert.Nil(t, err) + assert.False(t, ok) + + expected := database.Ancestry{ + Name: "ancestry-1", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3a"}, + }, + } + + expectedProcessors := database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + } + + // found + a, p, ok2, err := tx.FindAncestry("ancestry-1") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertAncestryEqual(t, expected, a) + assertProcessorsEqual(t, expectedProcessors, p) + } +} + +func assertAncestryWithFeatureEqual(t *testing.T, expected database.AncestryWithFeatures, actual database.AncestryWithFeatures) bool { + return assertAncestryEqual(t, expected.Ancestry, actual.Ancestry) && + assertNamespacedFeatureEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) +} +func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool { + return assert.Equal(t, expected.Name, actual.Name) && assert.Equal(t, expected.Layers, actual.Layers) +} + +func TestFindAncestryFeatures(t *testing.T) { + store, tx := openSessionForTest(t, "FindAncestryFeatures", true) + defer closeTest(t, store, tx) + + // invalid + _, ok, err := tx.FindAncestryFeatures("ancestry-non") + if assert.Nil(t, err) { + assert.False(t, ok) + } + + expected := database.AncestryWithFeatures{ + Ancestry: database.Ancestry{ + Name: "ancestry-2", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3b"}, + }, + }, + ProcessedBy: database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + }, + Features: []database.NamespacedFeature{ + { + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + }, + }, + { + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + }, + }, + } + // valid + ancestry, ok, err := tx.FindAncestryFeatures("ancestry-2") + if assert.Nil(t, err) && assert.True(t, ok) { + assertAncestryEqual(t, expected.Ancestry, ancestry.Ancestry) + assertNamespacedFeatureEqual(t, expected.Features, ancestry.Features) + assertProcessorsEqual(t, expected.ProcessedBy, ancestry.ProcessedBy) + } +} diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index ed038b4e..07d6f55f 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -27,135 +27,200 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/coreos/clair/pkg/strutil" ) const ( numVulnerabilities = 100 - numFeatureVersions = 100 + numFeatures = 100 ) -func TestRaceAffects(t *testing.T) { - datastore, err := openDatabaseForTest("RaceAffects", false) - if err != nil { - t.Error(err) - return +func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - defer datastore.Close() - // Insert the Feature on which we'll work. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestRaceAffectsFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestRaceAffecturesFeature1", + featureName := "TestFeature" + featureVersionFormat := dpkg.ParserName + // Insert the namespace on which we'll work. + namespace := database.Namespace{ + Name: "TestRaceAffectsFeatureNamespace1", + VersionFormat: dpkg.ParserName, } - _, err = datastore.insertFeature(feature) - if err != nil { - t.Error(err) - return + + if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) { + t.FailNow() } // Initialize random generator and enforce max procs. rand.Seed(time.Now().UnixNano()) runtime.GOMAXPROCS(runtime.NumCPU()) - // Generate FeatureVersions. - featureVersions := make([]database.FeatureVersion, numFeatureVersions) - for i := 0; i < numFeatureVersions; i++ { - version := rand.Intn(numFeatureVersions) + // Generate Distinct random features + features := make([]database.Feature, numFeatures) + nsFeatures := make([]database.NamespacedFeature, numFeatures) + for i := 0; i < numFeatures; i++ { + version := rand.Intn(numFeatures) - featureVersions[i] = database.FeatureVersion{ - Feature: feature, - Version: strconv.Itoa(version), + features[i] = database.Feature{ + Name: featureName, + VersionFormat: featureVersionFormat, + Version: strconv.Itoa(version), } + + nsFeatures[i] = database.NamespacedFeature{ + Namespace: namespace, + Feature: features[i], + } + } + + // insert features + if !assert.Nil(t, tx.PersistFeatures(features)) { + t.FailNow() } // Generate vulnerabilities. - // They are mapped by fixed version, which will make verification really easy afterwards. - vulnerabilities := make(map[int][]database.Vulnerability) + vulnerabilities := []database.VulnerabilityWithAffected{} for i := 0; i < numVulnerabilities; i++ { - version := rand.Intn(numFeatureVersions) + 1 + // any version less than this is vulnerable + version := rand.Intn(numFeatures) + 1 - // if _, ok := vulnerabilities[version]; !ok { - // vulnerabilities[version] = make([]database.Vulnerability) - // } - - vulnerability := database.Vulnerability{ - Name: uuid.New(), - Namespace: feature.Namespace, - FixedIn: []database.FeatureVersion{ + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: uuid.New(), + Namespace: namespace, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{ { - Feature: feature, - Version: strconv.Itoa(version), + Namespace: namespace, + FeatureName: featureName, + AffectedVersion: strconv.Itoa(version), + FixedInVersion: strconv.Itoa(version), }, }, - Severity: database.UnknownSeverity, } - vulnerabilities[version] = append(vulnerabilities[version], vulnerability) + vulnerabilities = append(vulnerabilities, vulnerability) } + tx.Commit() + + return nsFeatures, vulnerabilities +} + +func TestConcurrency(t *testing.T) { + store, err := openDatabaseForTest("Concurrency", false) + if !assert.Nil(t, err) { + t.FailNow() + } + defer store.Close() + + start := time.Now() + var wg sync.WaitGroup + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + nsNamespaces := genRandomNamespaces(t, 100) + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + assert.Nil(t, tx.PersistNamespaces(nsNamespaces)) + tx.Commit() + }() + } + wg.Wait() + fmt.Println("total", time.Since(start)) +} + +func genRandomNamespaces(t *testing.T, count int) []database.Namespace { + r := make([]database.Namespace, count) + for i := 0; i < count; i++ { + r[i] = database.Namespace{ + Name: uuid.New(), + VersionFormat: "dpkg", + } + } + return r +} + +func TestCaching(t *testing.T) { + store, err := openDatabaseForTest("Caching", false) + if !assert.Nil(t, err) { + t.FailNow() + } + defer store.Close() + + nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) + + fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities)) - // Insert featureversions and vulnerabilities in parallel. var wg sync.WaitGroup wg.Add(2) - go func() { defer wg.Done() - for _, vulnerabilitiesM := range vulnerabilities { - for _, vulnerability := range vulnerabilitiesM { - err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Nil(t, err) - } + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert vulnerabilities") + + assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) + fmt.Println("finished to insert namespaced features") + + tx.Commit() }() go func() { defer wg.Done() - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) - assert.Nil(t, err) + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert featureVersions") + + assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) + fmt.Println("finished to insert vulnerabilities") + tx.Commit() + }() wg.Wait() + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + defer tx.Rollback() + // Verify consistency now. - var actualAffectedNames []string - var expectedAffectedNames []string + affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) + if !assert.Nil(t, err) { + t.FailNow() + } - for _, featureVersion := range featureVersions { - featureVersionVersion, _ := strconv.Atoi(featureVersion.Version) - - // Get actual affects. - rows, err := datastore.Query(searchComplexTestFeatureVersionAffects, - featureVersion.ID) - assert.Nil(t, err) - defer rows.Close() - - var vulnName string - for rows.Next() { - err = rows.Scan(&vulnName) - if !assert.Nil(t, err) { - continue - } - actualAffectedNames = append(actualAffectedNames, vulnName) - } - if assert.Nil(t, rows.Err()) { - rows.Close() + for _, ansf := range affected { + if !assert.True(t, ansf.Valid) { + t.FailNow() } - // Get expected affects. - for i := numVulnerabilities; i > featureVersionVersion; i-- { - for _, vulnerability := range vulnerabilities[i] { - expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) + expectedAffectedNames := []string{} + for _, vuln := range vulnerabilities { + if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { + if ok { + expectedAffectedNames = append(expectedAffectedNames, vuln.Name) + } } } - assert.Len(t, compareStringLists(expectedAffectedNames, actualAffectedNames), 0) - assert.Len(t, compareStringLists(actualAffectedNames, expectedAffectedNames), 0) + actualAffectedNames := []string{} + for _, s := range ansf.AffectedBy { + actualAffectedNames = append(actualAffectedNames, s.Name) + } + + assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) } } diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index c39bd5b7..81ef857d 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -16,230 +16,366 @@ package pgsql import ( "database/sql" - "strings" - "time" + "errors" + "sort" + + "github.com/lib/pq" + log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { - if feature.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Feature") - } +var ( + errFeatureNotFound = errors.New("Feature not found") +) - // Do cache lookup. - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("feature").Inc() - id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) - if found { - promCacheHitsTotal.WithLabelValues("feature").Inc() - return id.(int), nil - } - } - - // We do `defer observeQueryTime` here because we don't want to observe cached features. - defer observeQueryTime("insertFeature", "all", time.Now()) - - // Find or create Namespace. - namespaceID, err := pgSQL.insertNamespace(feature.Namespace) - if err != nil { - return 0, err - } - - // Find or create Feature. - var id int - err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) - if err != nil { - return 0, handleError("soiFeature", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id) - } - - return id, nil +type vulnerabilityAffecting struct { + vulnerabilityID int64 + addedByID int64 } -func (pgSQL *pgSQL) insertFeatureVersion(fv database.FeatureVersion) (id int, err error) { - err = versionfmt.Valid(fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { - return 0, commonerr.NewBadRequestError("could not find/insert invalid FeatureVersion") +func (tx *pgSession) PersistFeatures(features []database.Feature) error { + if len(features) == 0 { + return nil } - // Do cache lookup. - cacheIndex := strings.Join([]string{"featureversion", fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("featureversion").Inc() - id, found := pgSQL.cache.Get(cacheIndex) - if found { - promCacheHitsTotal.WithLabelValues("featureversion").Inc() - return id.(int), nil + // Sorting is needed before inserting into database to prevent deadlock. + sort.Slice(features, func(i, j int) bool { + return features[i].Name < features[j].Name || + features[i].Version < features[j].Version || + features[i].VersionFormat < features[j].VersionFormat + }) + + // TODO(Sida): A better interface for bulk insertion is needed. + keys := make([]interface{}, len(features)*3) + for i, f := range features { + keys[i*3] = f.Name + keys[i*3+1] = f.Version + keys[i*3+2] = f.VersionFormat + if f.Name == "" || f.Version == "" || f.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") } } - // We do `defer observeQueryTime` here because we don't want to observe cached featureversions. - defer observeQueryTime("insertFeatureVersion", "all", time.Now()) - - // Find or create Feature first. - t := time.Now() - featureID, err := pgSQL.insertFeature(fv.Feature) - observeQueryTime("insertFeatureVersion", "insertFeature", t) - - if err != nil { - return 0, err - } - - fv.Feature.ID = featureID - - // Try to find the FeatureVersion. - // - // In a populated database, the likelihood of the FeatureVersion already being there is high. - // If we can find it here, we then avoid using a transaction and locking the database. - err = pgSQL.QueryRow(searchFeatureVersion, featureID, fv.Version).Scan(&fv.ID) - if err != nil && err != sql.ErrNoRows { - return 0, handleError("searchFeatureVersion", err) - } - if err == nil { - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil - } - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.Begin()", err) - } - - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertVulnerability to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t = time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertFeatureVersion", "lock", t) - - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err) - } - - // Find or create FeatureVersion. - var created bool - - t = time.Now() - err = tx.QueryRow(soiFeatureVersion, featureID, fv.Version).Scan(&created, &fv.ID) - observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t) - - if err != nil { - tx.Rollback() - return 0, handleError("soiFeatureVersion", err) - } - - if !created { - // The featureVersion already existed, no need to link it to - // vulnerabilities. - tx.Commit() - - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil - } - - // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in - // Vulnerability_Affects_FeatureVersion. - t = time.Now() - err = linkFeatureVersionToVulnerabilities(tx, fv) - observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) - - if err != nil { - tx.Rollback() - return 0, err - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - return 0, handleError("insertFeatureVersion.Commit()", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil + _, err := tx.Exec(queryPersistFeature(len(features)), keys...) + return handleError("queryPersistFeature", err) } -// TODO(Quentin-M): Batch me -func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) { - IDs := make([]int, 0, len(featureVersions)) +type namespacedFeatureWithID struct { + database.NamespacedFeature - for i := 0; i < len(featureVersions); i++ { - id, err := pgSQL.insertFeatureVersion(featureVersions[i]) - if err != nil { - return IDs, err - } - IDs = append(IDs, id) + ID int64 +} + +type vulnerabilityCache struct { + nsFeatureID int64 + vulnID int64 + vulnAffectingID int64 +} + +func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) { + if len(features) == 0 { + return nil, nil } - return IDs, nil -} - -type vulnerabilityAffectsFeatureVersion struct { - vulnerabilityID int - fixedInID int - fixedInVersion string -} - -func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { - // Select every vulnerability and the fixed version that affect this Feature. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID) + ids, err := tx.findNamespacedFeatureIDs(features) if err != nil { - return handleError("searchVulnerabilityFixedInFeature", err) + return nil, err } + + fMap := map[int64]database.NamespacedFeature{} + for i, f := range features { + if !ids[i].Valid { + return nil, errFeatureNotFound + } + fMap[ids[i].Int64] = f + } + + cacheTable := []vulnerabilityCache{} + rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids)) + if err != nil { + return nil, handleError("searchPotentialAffectingVulneraibilities", err) + } + defer rows.Close() - - var affects []vulnerabilityAffectsFeatureVersion for rows.Next() { - var affect vulnerabilityAffectsFeatureVersion + var ( + cache vulnerabilityCache + affected string + ) - err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) + err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID) if err != nil { - return handleError("searchVulnerabilityFixedInFeature.Scan()", err) + return nil, err } - cmp, err := versionfmt.Compare(featureVersion.Feature.Namespace.VersionFormat, featureVersion.Version, affect.fixedInVersion) - if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion we are inserting is lower than the fixed version on this - // Vulnerability, thus, this FeatureVersion is affected by it. - affects = append(affects, affect) + if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil { + return nil, err + } else if ok { + cacheTable = append(cacheTable, cache) } } - if err = rows.Err(); err != nil { - return handleError("searchVulnerabilityFixedInFeature.Rows()", err) - } - rows.Close() - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affect := range affects { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID, - featureVersion.ID, affect.fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) + return cacheTable, nil +} + +func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + _, err := tx.Exec(lockVulnerabilityAffects) + if err != nil { + return handleError("lockVulnerabilityAffects", err) + } + + cache, err := tx.searchAffectingVulnerabilities(features) + + keys := make([]interface{}, len(cache)*3) + for i, c := range cache { + keys[i*3] = c.vulnID + keys[i*3+1] = c.nsFeatureID + keys[i*3+2] = c.vulnAffectingID + } + + if len(cache) == 0 { + return nil + } + + affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...) + if err != nil { + return handleError("persistVulnerabilityAffectedNamespacedFeature", err) + } + if count, err := affected.RowsAffected(); err != nil { + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count) + } + return nil +} + +func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + nsIDs := map[database.Namespace]sql.NullInt64{} + fIDs := map[database.Feature]sql.NullInt64{} + for _, f := range features { + nsIDs[f.Namespace] = sql.NullInt64{} + fIDs[f.Feature] = sql.NullInt64{} + } + + fToFind := []database.Feature{} + for f := range fIDs { + fToFind = append(fToFind, f) + } + + sort.Slice(fToFind, func(i, j int) bool { + return fToFind[i].Name < fToFind[j].Name || + fToFind[i].Version < fToFind[j].Version || + fToFind[i].VersionFormat < fToFind[j].VersionFormat + }) + + if ids, err := tx.findFeatureIDs(fToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errFeatureNotFound + } + fIDs[fToFind[i]] = id } + } else { + return err + } + + nsToFind := []database.Namespace{} + for ns := range nsIDs { + nsToFind = append(nsToFind, ns) + } + + if ids, err := tx.findNamespaceIDs(nsToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errNamespaceNotFound + } + nsIDs[nsToFind[i]] = id + } + } else { + return err + } + + keys := make([]interface{}, len(features)*2) + for i, f := range features { + keys[i*2] = fIDs[f.Feature] + keys[i*2+1] = nsIDs[f.Namespace] + } + + _, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...) + if err != nil { + return err } return nil } + +// FindAffectedNamespacedFeatures looks up cache table and retrieves all +// vulnerabilities associated with the features. +func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + if len(features) == 0 { + return nil, nil + } + + returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) + + // featureMap is used to keep track of duplicated features. + featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{} + // initialize return value and generate unique feature request queries. + for i, f := range features { + returnFeatures[i] = database.NullableAffectedNamespacedFeature{ + AffectedNamespacedFeature: database.AffectedNamespacedFeature{ + NamespacedFeature: f, + }, + } + + featureMap[f] = append(featureMap[f], &returnFeatures[i]) + } + + // query unique namespaced features + distinctFeatures := []database.NamespacedFeature{} + for f := range featureMap { + distinctFeatures = append(distinctFeatures, f) + } + + nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures) + if err != nil { + return nil, err + } + + toQuery := []int64{} + featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{} + for i, id := range nsFeatureIDs { + if id.Valid { + toQuery = append(toQuery, id.Int64) + for _, f := range featureMap[distinctFeatures[i]] { + f.Valid = id.Valid + featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f) + } + } + } + + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery)) + if err != nil { + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) + } + defer rows.Close() + + for rows.Next() { + var ( + featureID int64 + vuln database.VulnerabilityWithFixedIn + ) + err := rows.Scan(&featureID, + &vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.FixedInVersion, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat, + ) + if err != nil { + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) + } + + for _, f := range featureIDMap[featureID] { + f.AffectedBy = append(f.AffectedBy, vuln) + } + } + + return returnFeatures, nil +} + +func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { + if len(nfs) == 0 { + return nil, nil + } + + nfsMap := map[database.NamespacedFeature]sql.NullInt64{} + keys := make([]interface{}, len(nfs)*4) + for i, nf := range nfs { + keys[i*4] = nfs[i].Name + keys[i*4+1] = nfs[i].Version + keys[i*4+2] = nfs[i].VersionFormat + keys[i*4+3] = nfs[i].Namespace.Name + nfsMap[nf] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) + if err != nil { + return nil, handleError("searchNamespacedFeature", err) + } + + defer rows.Close() + var ( + id sql.NullInt64 + nf database.NamespacedFeature + ) + + for rows.Next() { + err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name) + nf.Namespace.VersionFormat = nf.VersionFormat + if err != nil { + return nil, handleError("searchNamespacedFeature", err) + } + nfsMap[nf] = id + } + + ids := make([]sql.NullInt64, len(nfs)) + for i, nf := range nfs { + ids[i] = nfsMap[nf] + } + + return ids, nil +} + +func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) { + if len(fs) == 0 { + return nil, nil + } + + fMap := map[database.Feature]sql.NullInt64{} + + keys := make([]interface{}, len(fs)*3) + for i, f := range fs { + keys[i*3] = f.Name + keys[i*3+1] = f.Version + keys[i*3+2] = f.VersionFormat + fMap[f] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...) + if err != nil { + return nil, handleError("querySearchFeatureID", err) + } + defer rows.Close() + + var ( + id sql.NullInt64 + f database.Feature + ) + for rows.Next() { + err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat) + if err != nil { + return nil, handleError("querySearchFeatureID", err) + } + fMap[f] = id + } + + ids := make([]sql.NullInt64, len(fs)) + for i, f := range fs { + ids[i] = fMap[f] + } + + return ids, nil +} diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 5b7f8078..934b8cc1 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -20,96 +20,237 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" + + // register dpkg feature lister for testing + _ "github.com/coreos/clair/ext/featurefmt/dpkg" ) -func TestInsertFeature(t *testing.T) { - datastore, err := openDatabaseForTest("InsertFeature", false) +func TestPersistFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistFeatures", false) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{} + f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"} + + // empty + assert.Nil(t, tx.PersistFeatures([]database.Feature{})) + // invalid + assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1})) + // duplicated + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2})) + // existing + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2})) + + fs := listFeatures(t, tx) + assert.Len(t, fs, 1) + assert.Equal(t, f2, fs[0]) +} + +func TestPersistNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // existing features + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + } + + // non-existing features + f2 := database.Feature{ + Name: "fake!", + } + + f3 := database.Feature{ + Name: "openssl", + Version: "2.0", + VersionFormat: "dpkg", + } + + // exising namespace + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + n3 := database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + } + + // non-existing namespace + n2 := database.Namespace{ + Name: "debian:non", + VersionFormat: "dpkg", + } + + // existing namespaced feature + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // invalid namespaced feature + nf2 := database.NamespacedFeature{ + Namespace: n2, + Feature: f2, + } + + // new namespaced feature affected by vulnerability + nf3 := database.NamespacedFeature{ + Namespace: n3, + Feature: f3, + } + + // namespaced features with namespaces or features not in the database will + // generate error. + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) + + assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2})) + // valid case: insert nf3 + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3})) + + all := listNamespacedFeatures(t, tx) + assert.Contains(t, all, nf1) + assert.Contains(t, all, nf3) +} + +func TestVulnerableFeature(t *testing.T) { + datastore, tx := openSessionForTest(t, "VulnerableFeature", true) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{ + Name: "openssl", + Version: "1.3", + VersionFormat: "dpkg", + } + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + assert.Nil(t, tx.PersistFeatures([]database.Feature{f1})) + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1})) + assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})) + // ensure the namespaced feature is affected correctly + anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}) + if assert.Nil(t, err) && + assert.Len(t, anf, 1) && + assert.True(t, anf[0].Valid) && + assert.Len(t, anf[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name) + } +} + +func TestFindAffectedNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + ns := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + } + + ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns}) + if assert.Nil(t, err) && + assert.Len(t, ans, 1) && + assert.True(t, ans[0].Valid) && + assert.Len(t, ans[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) + } +} + +func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { + rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format + FROM feature AS f, namespace AS n, namespaced_feature AS nf + WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`) if err != nil { t.Error(err) - return - } - defer datastore.Close() - - // Invalid Feature. - id0, err := datastore.insertFeature(database.Feature{}) - assert.NotNil(t, err) - assert.Zero(t, id0) - - id0, err = datastore.insertFeature(database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature0", - }) - assert.NotNil(t, err) - assert.Zero(t, id0) - - // Insert Feature and ensure we can find it. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", - } - id1, err := datastore.insertFeature(feature) - assert.Nil(t, err) - id2, err := datastore.insertFeature(feature) - assert.Nil(t, err) - assert.Equal(t, id1, id2) - - // Insert invalid FeatureVersion. - for _, invalidFeatureVersion := range []database.FeatureVersion{ - { - Feature: database.Feature{}, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature2", - }, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "bad version", - }, - } { - id3, err := datastore.insertFeatureVersion(invalidFeatureVersion) - assert.Error(t, err) - assert.Zero(t, id3) + t.FailNow() } - // Insert FeatureVersion and ensure we can find it. - featureVersion := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", - }, - Version: "2:3.0-imba", + nf := []database.NamespacedFeature{} + for rows.Next() { + f := database.NamespacedFeature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat) + if err != nil { + t.Error(err) + t.FailNow() + } + nf = append(nf, f) } - id4, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - id5, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - assert.Equal(t, id4, id5) + + return nf +} + +func listFeatures(t *testing.T, tx *pgSession) []database.Feature { + rows, err := tx.Query("SELECT name, version, version_format FROM feature") + if err != nil { + t.FailNow() + } + + fs := []database.Feature{} + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) + if err != nil { + t.FailNow() + } + fs = append(fs, f) + } + return fs +} + +func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Feature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Name+" is expected") { + return false + } + return true + } + } + return false +} + +func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.NamespacedFeature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { + return false + } + } + return true + } + return false } diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index ab599588..1f85fab5 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -23,63 +23,35 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) -// InsertKeyValue stores (or updates) a single key / value tuple. -func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { +func (tx *pgSession) UpdateKeyValue(key, value string) (err error) { if key == "" || value == "" { log.Warning("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") } - defer observeQueryTime("InsertKeyValue", "all", time.Now()) + defer observeQueryTime("PersistKeyValue", "all", time.Now()) - // Upsert. - // - // Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS. - // The best solution is currently the use of http://dba.stackexchange.com/a/13477 - // but the key/value storage doesn't need to be super-efficient and super-safe at the - // moment so we can just use a client-side solution with transactions, based on - // http://postgresql.org/docs/current/static/plpgsql-control-structures.html. - // TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable. - - for { - // First, try to update. - r, err := pgSQL.Exec(updateKeyValue, value, key) - if err != nil { - return handleError("updateKeyValue", err) - } - if n, _ := r.RowsAffected(); n > 0 { - // Updated successfully. - return nil - } - - // Try to insert the key. - // If someone else inserts the same key concurrently, we could get a unique-key violation error. - _, err = pgSQL.Exec(insertKeyValue, key, value) - if err != nil { - if isErrUniqueViolation(err) { - // Got unique constraint violation, retry. - continue - } - return handleError("insertKeyValue", err) - } - - return nil + _, err = tx.Exec(upsertKeyValue, key, value) + if err != nil { + return handleError("insertKeyValue", err) } + + return nil } -// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. -func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { - defer observeQueryTime("GetKeyValue", "all", time.Now()) +func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { + defer observeQueryTime("FindKeyValue", "all", time.Now()) var value string - err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) + err := tx.QueryRow(searchKeyValue, key).Scan(&value) if err == sql.ErrNoRows { - return "", nil - } - if err != nil { - return "", handleError("searchKeyValue", err) + return "", false, nil } - return value, nil + if err != nil { + return "", false, handleError("searchKeyValue", err) + } + + return value, true, nil } diff --git a/database/pgsql/keyvalue_test.go b/database/pgsql/keyvalue_test.go index 4a8b6593..9991bf48 100644 --- a/database/pgsql/keyvalue_test.go +++ b/database/pgsql/keyvalue_test.go @@ -21,32 +21,30 @@ import ( ) func TestKeyValue(t *testing.T) { - datastore, err := openDatabaseForTest("KeyValue", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + datastore, tx := openSessionForTest(t, "KeyValue", true) + defer closeTest(t, datastore, tx) // Get non-existing key/value - f, err := datastore.GetKeyValue("test") + f, ok, err := tx.FindKeyValue("test") assert.Nil(t, err) - assert.Empty(t, "", f) + assert.False(t, ok) // Try to insert invalid key/value. - assert.Error(t, datastore.InsertKeyValue("test", "")) - assert.Error(t, datastore.InsertKeyValue("", "test")) - assert.Error(t, datastore.InsertKeyValue("", "")) + assert.Error(t, tx.UpdateKeyValue("test", "")) + assert.Error(t, tx.UpdateKeyValue("", "test")) + assert.Error(t, tx.UpdateKeyValue("", "")) // Insert and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test1")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test1")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test1", f) // Update and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test2")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test2")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test2", f) } diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 64e9a475..c7cd5ce2 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -16,464 +16,293 @@ package pgsql import ( "database/sql" - "strings" - "time" - - "github.com/guregu/null/zero" - log "github.com/sirupsen/logrus" + "sort" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { - subquery := "all" - if withFeatures { - subquery += "/features" - } else if withVulnerabilities { - subquery += "/features+vulnerabilities" - } - defer observeQueryTime("FindLayer", subquery, time.Now()) +func (tx *pgSession) FindLayer(hash string) (database.Layer, database.Processors, bool, error) { + l, p, _, ok, err := tx.findLayer(hash) + return l, p, ok, err +} - // Find the layer +func (tx *pgSession) FindLayerWithContent(hash string) (database.LayerWithContent, bool, error) { var ( - layer database.Layer - parentID zero.Int - parentName zero.String - nsID zero.Int - nsName sql.NullString - nsVersionFormat sql.NullString + layer database.LayerWithContent + layerID int64 + ok bool + err error ) - t := time.Now() - err := pgSQL.QueryRow(searchLayer, name).Scan( - &layer.ID, - &layer.Name, - &layer.EngineVersion, - &parentID, - &parentName, - ) - observeQueryTime("FindLayer", "searchLayer", t) - + layer.Layer, layer.ProcessedBy, layerID, ok, err = tx.findLayer(hash) if err != nil { - return layer, handleError("searchLayer", err) + return layer, false, err } - if !parentID.IsZero() { - layer.Parent = &database.Layer{ - Model: database.Model{ID: int(parentID.Int64)}, - Name: parentName.String, - } + if !ok { + return layer, false, nil } - rows, err := pgSQL.Query(searchLayerNamespace, layer.ID) - defer rows.Close() - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - for rows.Next() { - err = rows.Scan(&nsID, &nsName, &nsVersionFormat) - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - if !nsID.IsZero() { - layer.Namespaces = append(layer.Namespaces, database.Namespace{ - Model: database.Model{ID: int(nsID.Int64)}, - Name: nsName.String, - VersionFormat: nsVersionFormat.String, - }) - } - } - - // Find its features - if withFeatures || withVulnerabilities { - // Create a transaction to disable hash/merge joins as our experiments have shown that - // PostgreSQL 9.4 makes bad planning decisions about: - // - joining the layer tree to feature versions and feature - // - joining the feature versions to affected/fixed feature version and vulnerabilities - // It would for instance do a merge join between affected feature versions (300 rows, estimated - // 3000 rows) and fixed in feature version (100k rows). In this case, it is much more - // preferred to use a nested loop. - tx, err := pgSQL.Begin() - if err != nil { - return layer, handleError("FindLayer.Begin()", err) - } - defer tx.Commit() - - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable hash join") - } - _, err = tx.Exec(disableMergeJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable merge join") - } - - t = time.Now() - featureVersions, err := getLayerFeatureVersions(tx, layer.ID) - observeQueryTime("FindLayer", "getLayerFeatureVersions", t) - - if err != nil { - return layer, err - } - - layer.Features = featureVersions - - if withVulnerabilities { - // Load the vulnerabilities that affect the FeatureVersions. - t = time.Now() - err := loadAffectedBy(tx, layer.Features) - observeQueryTime("FindLayer", "loadAffectedBy", t) - - if err != nil { - return layer, err - } - } - } - - return layer, nil + layer.Features, err = tx.findLayerFeatures(layerID) + layer.Namespaces, err = tx.findLayerNamespaces(layerID) + return layer, true, nil } -// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has. -func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) { - var featureVersions []database.FeatureVersion +func (tx *pgSession) PersistLayer(layer database.Layer) error { + if layer.Hash == "" { + return commonerr.NewBadRequestError("Empty Layer Hash is not allowed") + } - // Query. - rows, err := tx.Query(searchLayerFeatureVersion, layerID) + _, err := tx.Exec(queryPersistLayer(1), layer.Hash) if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion", err) - } - defer rows.Close() - - // Scan query. - var modification string - mapFeatureVersions := make(map[int]database.FeatureVersion) - for rows.Next() { - var fv database.FeatureVersion - err = rows.Scan( - &fv.ID, - &modification, - &fv.Feature.Namespace.ID, - &fv.Feature.Namespace.Name, - &fv.Feature.Namespace.VersionFormat, - &fv.Feature.ID, - &fv.Feature.Name, - &fv.ID, - &fv.Version, - &fv.AddedBy.ID, - &fv.AddedBy.Name, - ) - if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) - } - - // Do transitive closure. - switch modification { - case "add": - mapFeatureVersions[fv.ID] = fv - case "del": - delete(mapFeatureVersions, fv.ID) - default: - log.WithField("modification", modification).Warning("unknown Layer_diff_FeatureVersion's modification") - return featureVersions, database.ErrInconsistent - } - } - if err = rows.Err(); err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err) + return handleError("queryPersistLayer", err) } - // Build result by converting our map to a slice. - for _, featureVersion := range mapFeatureVersions { - featureVersions = append(featureVersions, featureVersion) - } - - return featureVersions, nil + return nil } -// loadAffectedBy returns the list of database.Vulnerability that affect the given -// FeatureVersion. -func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error { - if len(featureVersions) == 0 { +// PersistLayerContent relates layer identified by hash with namespaces, +// features and processors provided. If the layer, namespaces, features are not +// in database, the function returns an error. +func (tx *pgSession) PersistLayerContent(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { + if hash == "" { + return commonerr.NewBadRequestError("Empty layer hash is not allowed") + } + + var layerID int64 + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) + if err != nil { + return err + } + + if err = tx.persistLayerNamespace(layerID, namespaces); err != nil { + return err + } + + if err = tx.persistLayerFeatures(layerID, features); err != nil { + return err + } + + if err = tx.persistLayerDetectors(layerID, processedBy.Detectors); err != nil { + return err + } + + if err = tx.persistLayerListers(layerID, processedBy.Listers); err != nil { + return err + } + + return nil +} + +func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error { + if len(detectors) == 0 { return nil } - // Construct list of FeatureVersion IDs, we will do a single query - featureVersionIDs := make([]int, 0, len(featureVersions)) - for i := 0; i < len(featureVersions); i++ { - featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID) + // Sorting is needed before inserting into database to prevent deadlock. + sort.Strings(detectors) + keys := make([]interface{}, len(detectors)*2) + for i, d := range detectors { + keys[i*2] = id + keys[i*2+1] = d + } + _, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...) + if err != nil { + return handleError("queryPersistLayerDetectors", err) + } + return nil +} + +func (tx *pgSession) persistLayerListers(id int64, listers []string) error { + if len(listers) == 0 { + return nil } - rows, err := tx.Query(searchFeatureVersionVulnerability, - buildInputArray(featureVersionIDs)) - if err != nil && err != sql.ErrNoRows { - return handleError("searchFeatureVersionVulnerability", err) + sort.Strings(listers) + keys := make([]interface{}, len(listers)*2) + for i, d := range listers { + keys[i*2] = id + keys[i*2+1] = d + } + + _, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...) + if err != nil { + return handleError("queryPersistLayerDetectors", err) + } + return nil +} + +func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error { + if len(features) == 0 { + return nil + } + + fIDs, err := tx.findFeatureIDs(features) + if err != nil { + return err + } + + ids := make([]int, len(fIDs)) + for i, fID := range fIDs { + if !fID.Valid { + return errNamespaceNotFound + } + ids[i] = int(fID.Int64) + } + + sort.IntSlice(ids).Sort() + keys := make([]interface{}, len(features)*2) + for i, fID := range ids { + keys[i*2] = id + keys[i*2+1] = fID + } + + _, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...) + if err != nil { + return handleError("queryPersistLayerFeature", err) + } + return nil +} + +func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error { + if len(namespaces) == 0 { + return nil + } + + nsIDs, err := tx.findNamespaceIDs(namespaces) + if err != nil { + return err + } + + // for every bulk persist operation, the input data should be sorted. + ids := make([]int, len(nsIDs)) + for i, nsID := range nsIDs { + if !nsID.Valid { + panic(errNamespaceNotFound) + } + ids[i] = int(nsID.Int64) + } + + sort.IntSlice(ids).Sort() + + keys := make([]interface{}, len(namespaces)*2) + for i, nsID := range ids { + keys[i*2] = id + keys[i*2+1] = nsID + } + + _, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) + if err != nil { + return handleError("queryPersistLayerNamespace", err) + } + return nil +} + +func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error { + stmt, err := tx.Prepare(listerQuery) + if err != nil { + return handleError(listerQueryName, err) + } + + for _, l := range processors.Listers { + _, err := stmt.Exec(id, l) + if err != nil { + stmt.Close() + return handleError(listerQueryName, err) + } + } + + if err := stmt.Close(); err != nil { + return handleError(listerQueryName, err) + } + + stmt, err = tx.Prepare(detectorQuery) + if err != nil { + return handleError(detectorQueryName, err) + } + + for _, d := range processors.Detectors { + _, err := stmt.Exec(id, d) + if err != nil { + stmt.Close() + return handleError(detectorQueryName, err) + } + } + + if err := stmt.Close(); err != nil { + return handleError(detectorQueryName, err) + } + + return nil +} + +func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) { + var namespaces []database.Namespace + + rows, err := tx.Query(searchLayerNamespaces, layerID) + if err != nil { + return nil, handleError("searchLayerFeatures", err) } - defer rows.Close() - vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions)) - var featureversionID int for rows.Next() { - var vulnerability database.Vulnerability - err := rows.Scan( - &featureversionID, - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.FixedBy, - ) + ns := database.Namespace{} + err := rows.Scan(&ns.Name, &ns.VersionFormat) if err != nil { - return handleError("searchFeatureVersionVulnerability.Scan()", err) + return nil, err } - vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) + namespaces = append(namespaces, ns) } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionVulnerability.Rows()", err) - } - - // Assign vulnerabilities to every FeatureVersions - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID] - } - - return nil + return namespaces, nil } -// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent, -// the Feature list will be compared to the parent's Feature list and the difference will be stored. -// Note that when the Namespace of a layer differs from its parent, it is expected that several -// Feature that were already included a parent will have their Namespace updated as well -// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed -// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't -// been modified. -func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { - tf := time.Now() +func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) { + var features []database.Feature - // Verify parameters - if layer.Name == "" { - log.Warning("could not insert a layer which has an empty Name") - return commonerr.NewBadRequestError("could not insert a layer which has an empty Name") - } - - // Get a potentially existing layer. - existingLayer, err := pgSQL.FindLayer(layer.Name, true, false) - if err != nil && err != commonerr.ErrNotFound { - return err - } else if err == nil { - if existingLayer.EngineVersion >= layer.EngineVersion { - // The layer exists and has an equal or higher engine version, do nothing. - return nil - } - - layer.ID = existingLayer.ID - } - - // We do `defer observeQueryTime` here because we don't want to observe existing layers. - defer observeQueryTime("InsertLayer", "all", tf) - - // Get parent ID. - var parentID zero.Int - if layer.Parent != nil { - if layer.Parent.ID == 0 { - log.Warning("Parent is expected to be retrieved from database when inserting a layer.") - return commonerr.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") - } - - parentID = zero.IntFrom(int64(layer.Parent.ID)) - } - - // namespaceIDs will contain inherited and new namespaces - namespaceIDs := make(map[int]struct{}) - - // try to insert the new namespaces - for _, ns := range layer.Namespaces { - n, err := pgSQL.insertNamespace(ns) - if err != nil { - return handleError("pgSQL.insertNamespace", err) - } - namespaceIDs[n] = struct{}{} - } - - // inherit namespaces from parent layer - if layer.Parent != nil { - for _, ns := range layer.Parent.Namespaces { - namespaceIDs[ns.ID] = struct{}{} - } - } - - // Begin transaction. - tx, err := pgSQL.Begin() + rows, err := tx.Query(searchLayerFeatures, layerID) if err != nil { - tx.Rollback() - return handleError("InsertLayer.Begin()", err) + return nil, handleError("searchLayerFeatures", err) } - if layer.ID == 0 { - // Insert a new layer. - err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID). - Scan(&layer.ID) + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) if err != nil { - tx.Rollback() - - if isErrUniqueViolation(err) { - // Ignore this error, another process collided. - log.Debug("Attempted to insert duplicate layer.") - return nil - } - return handleError("insertLayer", err) - } - } else { - // Update an existing layer. - _, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion) - if err != nil { - tx.Rollback() - return handleError("updateLayer", err) - } - - // replace the old namespace in the database - _, err := tx.Exec(removeLayerNamespace, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerNamespace", err) - } - // Remove all existing Layer_diff_FeatureVersion. - _, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerDiffFeatureVersion", err) + return nil, err } + features = append(features, f) } - - // insert the layer's namespaces - stmt, err := tx.Prepare(insertLayerNamespace) - - if err != nil { - tx.Rollback() - return handleError("failed to prepare statement", err) - } - - defer func() { - err = stmt.Close() - if err != nil { - tx.Rollback() - log.WithError(err).Error("failed to close prepared statement") - } - }() - - for nsid := range namespaceIDs { - _, err := stmt.Exec(layer.ID, nsid) - if err != nil { - tx.Rollback() - return handleError("insertLayerNamespace", err) - } - } - - // Update Layer_diff_FeatureVersion now. - err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) - if err != nil { - tx.Rollback() - return err - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("InsertLayer.Commit()", err) - } - - return nil + return features, nil } -func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error { - // add and del are the FeatureVersion diff we should insert. - var add []database.FeatureVersion - var del []database.FeatureVersion +func (tx *pgSession) findLayer(hash string) (database.Layer, database.Processors, int64, bool, error) { + var ( + layerID int64 + layer = database.Layer{Hash: hash} + processors database.Processors + ) - if layer.Parent == nil { - // There is no parent, every Features are added. - add = append(add, layer.Features...) - } else if layer.Parent != nil { - // There is a parent, we need to diff the Features with it. - - // Build name:version structures. - layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) - parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) - - // Calculate the added and deleted FeatureVersions name:version. - addNV := compareStringLists(layerFeaturesNV, parentLayerFeaturesNV) - delNV := compareStringLists(parentLayerFeaturesNV, layerFeaturesNV) - - // Fill the structures containing the added and deleted FeatureVersions. - for _, nv := range addNV { - add = append(add, *layerFeaturesMapNV[nv]) - } - for _, nv := range delNV { - del = append(del, *parentLayerFeaturesMapNV[nv]) - } + if hash == "" { + return layer, processors, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed") } - // Insert FeatureVersions in the database. - addIDs, err := pgSQL.insertFeatureVersions(add) + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) if err != nil { - return err + if err == sql.ErrNoRows { + return layer, processors, layerID, false, nil + } + return layer, processors, layerID, false, err } - delIDs, err := pgSQL.insertFeatureVersions(del) + + processors.Detectors, err = tx.findProcessors(searchLayerDetectors, "searchLayerDetectors", "detector", layerID) if err != nil { - return err + return layer, processors, layerID, false, err } - // Insert diff in the database. - if len(addIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs)) - if err != nil { - return handleError("insertLayerDiffFeatureVersion.Add", err) - } - } - if len(delIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs)) - if err != nil { - return handleError("insertLayerDiffFeatureVersion.Del", err) - } + processors.Listers, err = tx.findProcessors(searchLayerListers, "searchLayerListers", "lister", layerID) + if err != nil { + return layer, processors, layerID, false, err } - return nil -} - -func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) { - mapNV := make(map[string]*database.FeatureVersion, 0) - sliceNV := make([]string, 0, len(features)) - - for i := 0; i < len(features); i++ { - fv := &features[i] - nv := strings.Join([]string{fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - mapNV[nv] = fv - sliceNV = append(sliceNV, nv) - } - - return mapNV, sliceNV -} - -func (pgSQL *pgSQL) DeleteLayer(name string) error { - defer observeQueryTime("DeleteLayer", "all", time.Now()) - - result, err := pgSQL.Exec(removeLayer, name) - if err != nil { - return handleError("removeLayer", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return handleError("removeLayer.RowsAffected()", err) - } - - if affected <= 0 { - return commonerr.ErrNotFound - } - - return nil + return layer, processors, layerID, true, nil } diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index 6f35bbde..e823a048 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -15,423 +15,100 @@ package pgsql import ( - "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) +func TestPersistLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayer", false) + defer closeTest(t, datastore, tx) + + l1 := database.Layer{} + l2 := database.Layer{Hash: "HESOYAM"} + + // invalid + assert.NotNil(t, tx.PersistLayer(l1)) + // valid + assert.Nil(t, tx.PersistLayer(l2)) + // duplicated + assert.Nil(t, tx.PersistLayer(l2)) +} + +func TestPersistLayerProcessors(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayerProcessors", true) + defer closeTest(t, datastore, tx) + + // invalid + assert.NotNil(t, tx.PersistLayerContent("hash", []database.Namespace{}, []database.Feature{}, database.Processors{})) + // valid + assert.Nil(t, tx.PersistLayerContent("layer-4", []database.Namespace{}, []database.Feature{}, database.Processors{Detectors: []string{"new detector!"}})) +} + func TestFindLayer(t *testing.T) { - datastore, err := openDatabaseForTest("FindLayer", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + datastore, tx := openSessionForTest(t, "FindLayer", true) + defer closeTest(t, datastore, tx) - // Layer-0: no parent, no namespace, no feature, no vulnerability - layer, err := datastore.FindLayer("layer-0", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, "layer-0", layer.Name) - assert.Len(t, layer.Namespaces, 0) - assert.Nil(t, layer.Parent) - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) + expected := database.Layer{Hash: "layer-4"} + expectedProcessors := database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, } - layer, err = datastore.FindLayer("layer-0", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Len(t, layer.Features, 0) - } - - // Layer-1: one parent, adds two features, one vulnerability - layer, err = datastore.FindLayer("layer-1", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, layer.Name, "layer-1") - assertExpectedNamespaceName(t, &layer, []string{"debian:7"}) - if assert.NotNil(t, layer.Parent) { - assert.Equal(t, "layer-0", layer.Parent.Name) - } - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) - } - - layer, err = datastore.FindLayer("layer-1", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) - - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } - } - - layer, err = datastore.FindLayer("layer-1", true, true) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) - - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) - - if assert.Len(t, featureVersion.AffectedBy, 1) { - assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name) - assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name) - assert.Equal(t, database.HighSeverity, featureVersion.AffectedBy[0].Severity) - assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description) - assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link) - assert.Equal(t, "2.0", featureVersion.AffectedBy[0].FixedBy) - } - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } - } - - // Testing Multiple namespaces layer-3b has debian:7 and debian:8 namespaces - layer, err = datastore.FindLayer("layer-3b", true, true) - - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - assert.Equal(t, "layer-3b", layer.Name) - // validate the namespace - assertExpectedNamespaceName(t, &layer, []string{"debian:7", "debian:8"}) - for _, featureVersion := range layer.Features { - switch featureVersion.Feature.Namespace.Name { - case "debian:7": - assert.Equal(t, "wechat", featureVersion.Feature.Name) - assert.Equal(t, "0.5", featureVersion.Version) - case "debian:8": - assert.Equal(t, "openssl", featureVersion.Feature.Name) - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-3b", featureVersion.Feature.Name) - } - } - } -} - -func TestInsertLayer(t *testing.T) { - datastore, err := openDatabaseForTest("InsertLayer", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Insert invalid layer. - testInsertLayerInvalid(t, datastore) - - // Insert a layer tree. - testInsertLayerTree(t, datastore) - - // Update layer. - testInsertLayerUpdate(t, datastore) - - // Delete layer. - testInsertLayerDelete(t, datastore) -} - -func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) { - invalidLayers := []database.Layer{ - {}, - {Name: "layer0", Parent: &database.Layer{}}, - {Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}}, - } - - for _, invalidLayer := range invalidLayers { - err := datastore.InsertLayer(invalidLayer) - assert.Error(t, err) - } -} - -func testInsertLayerTree(t *testing.T, datastore database.Datastore) { - f1 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature1", - }, - Version: "1.0", - } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature4", - }, - Version: "0.666", - } - - layers := []database.Layer{ - { - Name: "TestInsertLayer1", - }, - { - Name: "TestInsertLayer2", - Parent: &database.Layer{Name: "TestInsertLayer1"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace1", - VersionFormat: dpkg.ParserName, - }}, - }, - // This layer changes the namespace and adds Features. - { - Name: "TestInsertLayer3", - Parent: &database.Layer{Name: "TestInsertLayer2"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f1, f2, f3}, - }, - // This layer covers the case where the last layer doesn't provide any new Feature. - { - Name: "TestInsertLayer4a", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f1, f2, f3}, - }, - // This layer covers the case where the last layer provides Features. - // It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their - // Namespaces should then remain unchanged. - { - Name: "TestInsertLayer4b", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{ - // Deletes TestInsertLayerFeature1. - // Keep TestInsertLayerFeature2 (old Namespace should be kept): - f4, - // Upgrades TestInsertLayerFeature3 (with new Namespace): - f5, - // Adds TestInsertLayerFeature4: - f6, - }, - }, - } - - var err error - retrievedLayers := make(map[string]database.Layer) - for _, layer := range layers { - if layer.Parent != nil { - // Retrieve from database its parent and assign. - parent := retrievedLayers[layer.Parent.Name] - layer.Parent = &parent - } - - err = datastore.InsertLayer(layer) - assert.Nil(t, err) - - retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false) - assert.Nil(t, err) - } - - // layer inherits all namespaces from its ancestries - l4a := retrievedLayers["TestInsertLayer4a"] - assertExpectedNamespaceName(t, &l4a, []string{"TestInsertLayerNamespace2", "TestInsertLayerNamespace1"}) - assert.Len(t, l4a.Features, 3) - for _, featureVersion := range l4a.Features { - if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3)) - } - } - - l4b := retrievedLayers["TestInsertLayer4b"] - assertExpectedNamespaceName(t, &l4b, []string{"TestInsertLayerNamespace1", "TestInsertLayerNamespace2", "TestInsertLayerNamespace3"}) - assert.Len(t, l4b.Features, 3) - for _, featureVersion := range l4b.Features { - if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6)) - } - } -} - -func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) { - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature7", - }, - Version: "0.01", - } - - l3, _ := datastore.FindLayer("TestInsertLayer3", true, false) - l3u := database.Layer{ - Name: l3.Name, - Parent: l3.Parent, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespaceUpdated1", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f7}, - } - - l4u := database.Layer{ - Name: "TestInsertLayer4", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f7}, - EngineVersion: 2, - } - - // Try to re-insert without increasing the EngineVersion. - err := datastore.InsertLayer(l3u) + // invalid + _, _, _, err := tx.FindLayer("") + assert.NotNil(t, err) + _, _, ok, err := tx.FindLayer("layer-non") assert.Nil(t, err) + assert.False(t, ok) - l3uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3, &l3uf) - assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion) - assert.Len(t, l3uf.Features, len(l3.Features)) + // valid + layer, processors, ok2, err := tx.FindLayer("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assert.Equal(t, expected, layer) + assertProcessorsEqual(t, expectedProcessors, processors) } +} - // Update layer l3. - // Verify that the Namespace, EngineVersion and FeatureVersions got updated. - l3u.EngineVersion = 2 - err = datastore.InsertLayer(l3u) +func TestFindLayerWithContent(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindLayerWithContent", true) + defer closeTest(t, datastore, tx) + + _, _, err := tx.FindLayerWithContent("") + assert.NotNil(t, err) + _, ok, err := tx.FindLayerWithContent("layer-non") assert.Nil(t, err) + assert.False(t, ok) - l3uf, err = datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l3uf) - assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion) - if assert.Len(t, l3uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0]) - } + expectedL := database.LayerWithContent{ + Layer: database.Layer{ + Hash: "layer-4", + }, + Features: []database.Feature{ + {Name: "fake", Version: "2.0", VersionFormat: "rpm"}, + {Name: "openssl", Version: "2.0", VersionFormat: "dpkg"}, + }, + Namespaces: []database.Namespace{ + {Name: "debian:7", VersionFormat: "dpkg"}, + {Name: "fake:1.0", VersionFormat: "rpm"}, + }, + ProcessedBy: database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, + }, } - // Update layer l4. - // Verify that the Namespace got updated from its new Parent's, and also verify the - // EnginVersion and FeatureVersions. - l4u.Parent = &l3uf - err = datastore.InsertLayer(l4u) - assert.Nil(t, err) - - l4uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l4uf) - assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion) - if assert.Len(t, l4uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0]) - } + layer, ok2, err := tx.FindLayerWithContent("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertLayerWithContentEqual(t, expectedL, layer) } } -func assertSameNamespaceName(t *testing.T, layer1 *database.Layer, layer2 *database.Layer) { - assert.Len(t, compareStringLists(extractNamespaceName(layer1), extractNamespaceName(layer2)), 0) -} - -func assertExpectedNamespaceName(t *testing.T, layer *database.Layer, expectedNames []string) { - assert.Len(t, compareStringLists(extractNamespaceName(layer), expectedNames), 0) -} - -func extractNamespaceName(layer *database.Layer) []string { - slist := make([]string, 0, len(layer.Namespaces)) - for _, ns := range layer.Namespaces { - slist = append(slist, ns.Name) - } - return slist -} - -func testInsertLayerDelete(t *testing.T, datastore database.Datastore) { - err := datastore.DeleteLayer("TestInsertLayerX") - assert.Equal(t, commonerr.ErrNotFound, err) - - // ensure layer_namespace table is cleaned up once a layer is removed - layer3, err := datastore.FindLayer("TestInsertLayer3", false, false) - layer4a, err := datastore.FindLayer("TestInsertLayer4a", false, false) - layer4b, err := datastore.FindLayer("TestInsertLayer4b", false, false) - - err = datastore.DeleteLayer("TestInsertLayer3") - assert.Nil(t, err) - - _, err = datastore.FindLayer("TestInsertLayer3", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer3.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4a", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4a.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4b", true, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4b.ID, datastore) -} - -func assertNotInLayerNamespace(t *testing.T, layerID int, datastore database.Datastore) { - pg, ok := datastore.(*pgSQL) - if !assert.True(t, ok) { - return - } - tx, err := pg.Begin() - if !assert.Nil(t, err) { - return - } - rows, err := tx.Query(searchLayerNamespace, layerID) - assert.False(t, rows.Next()) -} - -func cmpFV(a, b database.FeatureVersion) bool { - return a.Feature.Name == b.Feature.Name && - a.Feature.Namespace.Name == b.Feature.Namespace.Name && - a.Version == b.Version +func assertLayerWithContentEqual(t *testing.T, expected database.LayerWithContent, actual database.LayerWithContent) bool { + return assert.Equal(t, expected.Layer, actual.Layer) && + assertFeaturesEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) && + assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces) } diff --git a/database/pgsql/lock.go b/database/pgsql/lock.go index d3521b75..c8918ebc 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock.go @@ -15,6 +15,7 @@ package pgsql import ( + "errors" "time" log "github.com/sirupsen/logrus" @@ -22,86 +23,91 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) +var ( + errLockNotFound = errors.New("lock is not in database") +) + // Lock tries to set a temporary lock in the database. // // Lock does not block, instead, it returns true and its expiration time -// is the lock has been successfully acquired or false otherwise -func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { +// is the lock has been successfully acquired or false otherwise. +func (tx *pgSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) { if name == "" || owner == "" || duration == 0 { log.Warning("could not create an invalid lock") - return false, time.Time{} + return false, time.Time{}, commonerr.NewBadRequestError("Invalid Lock Parameters") } - defer observeQueryTime("Lock", "all", time.Now()) - - // Compute expiration. until := time.Now().Add(duration) - if renew { + defer observeQueryTime("Lock", "update", time.Now()) // Renew lock. - r, err := pgSQL.Exec(updateLock, name, owner, until) + r, err := tx.Exec(updateLock, name, owner, until) if err != nil { - handleError("updateLock", err) - return false, until + return false, until, handleError("updateLock", err) } - if n, _ := r.RowsAffected(); n > 0 { - // Updated successfully. - return true, until + + if n, err := r.RowsAffected(); err == nil { + return n > 0, until, nil } - } else { - // Prune locks. - pgSQL.pruneLocks() + return false, until, handleError("updateLock", err) + } else if err := tx.pruneLocks(); err != nil { + return false, until, err } // Lock. - _, err := pgSQL.Exec(insertLock, name, owner, until) + defer observeQueryTime("Lock", "soiLock", time.Now()) + _, err := tx.Exec(soiLock, name, owner, until) if err != nil { - if !isErrUniqueViolation(err) { - handleError("insertLock", err) + if isErrUniqueViolation(err) { + return false, until, nil } - return false, until + return false, until, handleError("insertLock", err) } - - return true, until + return true, until, nil } // Unlock unlocks a lock specified by its name if I own it -func (pgSQL *pgSQL) Unlock(name, owner string) { +func (tx *pgSession) Unlock(name, owner string) error { if name == "" || owner == "" { - log.Warning("could not delete an invalid lock") - return + return commonerr.NewBadRequestError("Invalid Lock Parameters") } defer observeQueryTime("Unlock", "all", time.Now()) - pgSQL.Exec(removeLock, name, owner) + _, err := tx.Exec(removeLock, name, owner) + return err } // FindLock returns the owner of a lock specified by its name and its // expiration time. -func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { +func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) { if name == "" { - log.Warning("could not find an invalid lock") - return "", time.Time{}, commonerr.NewBadRequestError("could not find an invalid lock") + return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock") } defer observeQueryTime("FindLock", "all", time.Now()) var owner string var until time.Time - err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until) + err := tx.QueryRow(searchLock, name).Scan(&owner, &until) if err != nil { - return owner, until, handleError("searchLock", err) + return owner, until, false, handleError("searchLock", err) } - return owner, until, nil + return owner, until, true, nil } // pruneLocks removes every expired locks from the database -func (pgSQL *pgSQL) pruneLocks() { +func (tx *pgSession) pruneLocks() error { defer observeQueryTime("pruneLocks", "all", time.Now()) - if _, err := pgSQL.Exec(removeLockExpired); err != nil { - handleError("removeLockExpired", err) + if r, err := tx.Exec(removeLockExpired); err != nil { + return handleError("removeLockExpired", err) + } else if affected, err := r.RowsAffected(); err != nil { + return handleError("removeLockExpired", err) + } else { + log.Debugf("Pruned %d Locks", affected) } + + return nil } diff --git a/database/pgsql/lock_test.go b/database/pgsql/lock_test.go index cbd2d999..19a5a934 100644 --- a/database/pgsql/lock_test.go +++ b/database/pgsql/lock_test.go @@ -22,48 +22,72 @@ import ( ) func TestLock(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return - } + datastore, tx := openSessionForTest(t, "Lock", true) defer datastore.Close() var l bool var et time.Time // Create a first lock. - l, _ = datastore.Lock("test1", "owner1", time.Minute, false) + l, _, err := tx.Lock("test1", "owner1", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) - // Try to lock the same lock with another owner. - l, _ = datastore.Lock("test1", "owner2", time.Minute, true) + // lock again by itself, the previous lock is not expired yet. + l, _, err = tx.Lock("test1", "owner1", time.Minute, false) + assert.Nil(t, err) assert.False(t, l) + tx = restartSession(t, datastore, tx, false) - l, _ = datastore.Lock("test1", "owner2", time.Minute, false) + // Try to renew the same lock with another owner. + l, _, err = tx.Lock("test1", "owner2", time.Minute, true) + assert.Nil(t, err) assert.False(t, l) + tx = restartSession(t, datastore, tx, false) + + l, _, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.Nil(t, err) + assert.False(t, l) + tx = restartSession(t, datastore, tx, false) // Renew the lock. - l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true) + l, _, err = tx.Lock("test1", "owner1", 2*time.Minute, true) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // Unlock and then relock by someone else. - datastore.Unlock("test1", "owner1") + err = tx.Unlock("test1", "owner1") + assert.Nil(t, err) + tx = restartSession(t, datastore, tx, true) - l, et = datastore.Lock("test1", "owner2", time.Minute, false) + l, et, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // LockInfo - o, et2, err := datastore.FindLock("test1") + o, et2, ok, err := tx.FindLock("test1") + assert.True(t, ok) assert.Nil(t, err) assert.Equal(t, "owner2", o) assert.Equal(t, et.Second(), et2.Second()) + tx = restartSession(t, datastore, tx, true) // Create a second lock which is actually already expired ... - l, _ = datastore.Lock("test2", "owner1", -time.Minute, false) + l, _, err = tx.Lock("test2", "owner1", -time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // Take over the lock - l, _ = datastore.Lock("test2", "owner2", time.Minute, false) + l, _, err = tx.Lock("test2", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) + + if !assert.Nil(t, tx.Rollback()) { + t.FailNow() + } } diff --git a/database/pgsql/migrations/00001_change_migrator.go b/database/pgsql/migrations/00001_change_migrator.go deleted file mode 100644 index 8fef9ea0..00000000 --- a/database/pgsql/migrations/00001_change_migrator.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import ( - "database/sql" - - "github.com/remind101/migrate" -) - -func init() { - // This migration removes the data maintained by the previous migration tool - // (liamstask/goose), and if it was present, mark the 00002_initial_schema - // migration as done. - RegisterMigration(migrate.Migration{ - ID: 1, - Up: func(tx *sql.Tx) error { - // Verify that goose was in use before, otherwise skip this migration. - var e bool - err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e) - if err == sql.ErrNoRows { - return nil - } - if err != nil { - return err - } - - // Delete goose's data. - _, err = tx.Exec("DROP TABLE goose_db_version CASCADE") - if err != nil { - return err - } - - // Mark the '00002_initial_schema' as done. - _, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)") - - return err - }, - Down: migrate.Queries([]string{}), - }) -} diff --git a/database/pgsql/migrations/00001_initial_schema.go b/database/pgsql/migrations/00001_initial_schema.go new file mode 100644 index 00000000..14fff7d4 --- /dev/null +++ b/database/pgsql/migrations/00001_initial_schema.go @@ -0,0 +1,192 @@ +// Copyright 2016 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migrations + +import "github.com/remind101/migrate" + +func init() { + RegisterMigration(migrate.Migration{ + ID: 1, + Up: migrate.Queries([]string{ + // namespaces + `CREATE TABLE IF NOT EXISTS namespace ( + id SERIAL PRIMARY KEY, + name TEXT NULL, + version_format TEXT, + UNIQUE (name, version_format));`, + `CREATE INDEX ON namespace(name);`, + + // features + `CREATE TABLE IF NOT EXISTS feature ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + version TEXT NOT NULL, + version_format TEXT NOT NULL, + UNIQUE (name, version, version_format));`, + `CREATE INDEX ON feature(name);`, + + `CREATE TABLE IF NOT EXISTS namespaced_feature ( + id SERIAL PRIMARY KEY, + namespace_id INT REFERENCES namespace, + feature_id INT REFERENCES feature, + UNIQUE (namespace_id, feature_id));`, + + // layers + `CREATE TABLE IF NOT EXISTS layer( + id SERIAL PRIMARY KEY, + hash TEXT NOT NULL UNIQUE);`, + + `CREATE TABLE IF NOT EXISTS layer_feature ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + feature_id INT REFERENCES feature ON DELETE CASCADE, + UNIQUE (layer_id, feature_id));`, + `CREATE INDEX ON layer_feature(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_lister ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + lister TEXT NOT NULL, + UNIQUE (layer_id, lister));`, + `CREATE INDEX ON layer_lister(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_detector ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + detector TEXT, + UNIQUE (layer_id, detector));`, + `CREATE INDEX ON layer_detector(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_namespace ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + namespace_id INT REFERENCES namespace ON DELETE CASCADE, + UNIQUE (layer_id, namespace_id));`, + `CREATE INDEX ON layer_namespace(layer_id);`, + + // ancestry + `CREATE TABLE IF NOT EXISTS ancestry ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL UNIQUE);`, + + `CREATE TABLE IF NOT EXISTS ancestry_layer ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + ancestry_index INT NOT NULL, + layer_id INT REFERENCES layer ON DELETE RESTRICT, + UNIQUE (ancestry_id, ancestry_index));`, + `CREATE INDEX ON ancestry_layer(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_feature ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + namespaced_feature_id INT REFERENCES namespaced_feature ON DELETE CASCADE, + UNIQUE (ancestry_id, namespaced_feature_id));`, + + `CREATE TABLE IF NOT EXISTS ancestry_lister ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + lister TEXT, + UNIQUE (ancestry_id, lister));`, + `CREATE INDEX ON ancestry_lister(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_detector ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + detector TEXT, + UNIQUE (ancestry_id, detector));`, + `CREATE INDEX ON ancestry_detector(ancestry_id);`, + + `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, + + // vulnerability + `CREATE TABLE IF NOT EXISTS vulnerability ( + id SERIAL PRIMARY KEY, + namespace_id INT NOT NULL REFERENCES Namespace, + name TEXT NOT NULL, + description TEXT NULL, + link TEXT NULL, + severity severity NOT NULL, + metadata TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE, + deleted_at TIMESTAMP WITH TIME ZONE NULL);`, + `CREATE INDEX ON vulnerability(namespace_id, name);`, + `CREATE INDEX ON vulnerability(namespace_id);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_feature ( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + feature_name TEXT NOT NULL, + affected_version TEXT, + fixedin TEXT);`, + `CREATE INDEX ON vulnerability_affected_feature(vulnerability_id, feature_name);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_namespaced_feature( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + namespaced_feature_id INT NOT NULL REFERENCES namespaced_feature ON DELETE CASCADE, + added_by INT NOT NULL REFERENCES vulnerability_affected_feature ON DELETE CASCADE, + UNIQUE (vulnerability_id, namespaced_feature_id));`, + `CREATE INDEX ON vulnerability_affected_namespaced_feature(namespaced_feature_id);`, + + `CREATE TABLE IF NOT EXISTS KeyValue ( + id SERIAL PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + value TEXT);`, + + `CREATE TABLE IF NOT EXISTS Lock ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + owner VARCHAR(64) NOT NULL, + until TIMESTAMP WITH TIME ZONE);`, + `CREATE INDEX ON Lock (owner);`, + + // Notification + `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + created_at TIMESTAMP WITH TIME ZONE, + notified_at TIMESTAMP WITH TIME ZONE NULL, + deleted_at TIMESTAMP WITH TIME ZONE NULL, + old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, + new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, + `CREATE INDEX ON Vulnerability_Notification (notified_at);`, + }), + Down: migrate.Queries([]string{ + `DROP TABLE IF EXISTS + ancestry, + ancestry_layer, + ancestry_feature, + ancestry_detector, + ancestry_lister, + feature, + namespaced_feature, + keyvalue, + layer, + layer_detector, + layer_feature, + layer_lister, + layer_namespace, + lock, + namespace, + vulnerability, + vulnerability_affected_feature, + vulnerability_affected_namespaced_feature, + vulnerability_notification + CASCADE;`, + `DROP TYPE IF EXISTS severity;`, + }), + }) +} diff --git a/database/pgsql/migrations/00002_initial_schema.go b/database/pgsql/migrations/00002_initial_schema.go deleted file mode 100644 index f7cc17e6..00000000 --- a/database/pgsql/migrations/00002_initial_schema.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - // This migration creates the initial Clair's schema. - RegisterMigration(migrate.Migration{ - ID: 2, - Up: migrate.Queries([]string{ - `CREATE TABLE IF NOT EXISTS Namespace ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NULL);`, - - `CREATE TABLE IF NOT EXISTS Layer ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NOT NULL UNIQUE, - engineversion SMALLINT NOT NULL, - parent_id INT NULL REFERENCES Layer ON DELETE CASCADE, - namespace_id INT NULL REFERENCES Namespace, - created_at TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Layer (parent_id);`, - `CREATE INDEX ON Layer (namespace_id);`, - - `CREATE TABLE IF NOT EXISTS Feature ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - UNIQUE (namespace_id, name));`, - - `CREATE TABLE IF NOT EXISTS FeatureVersion ( - id SERIAL PRIMARY KEY, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL);`, - `CREATE INDEX ON FeatureVersion (feature_id);`, - - `CREATE TYPE modification AS ENUM ('add', 'del');`, - `CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion ( - id SERIAL PRIMARY KEY, - layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - modification modification NOT NULL, - UNIQUE (layer_id, featureversion_id));`, - `CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`, - - `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, - `CREATE TABLE IF NOT EXISTS Vulnerability ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - description TEXT NULL, - link VARCHAR(128) NULL, - severity severity NOT NULL, - metadata TEXT NULL, - created_at TIMESTAMP WITH TIME ZONE, - deleted_at TIMESTAMP WITH TIME ZONE NULL);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL, - UNIQUE (vulnerability_id, feature_id));`, - `CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE, - UNIQUE (vulnerability_id, featureversion_id));`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS KeyValue ( - id SERIAL PRIMARY KEY, - key VARCHAR(128) NOT NULL UNIQUE, - value TEXT);`, - - `CREATE TABLE IF NOT EXISTS Lock ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - owner VARCHAR(64) NOT NULL, - until TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Lock (owner);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - created_at TIMESTAMP WITH TIME ZONE, - notified_at TIMESTAMP WITH TIME ZONE NULL, - deleted_at TIMESTAMP WITH TIME ZONE NULL, - old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, - new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, - `CREATE INDEX ON Vulnerability_Notification (notified_at);`, - }), - Down: migrate.Queries([]string{ - `DROP TABLE IF EXISTS - Namespace, - Layer, - Feature, - FeatureVersion, - Layer_diff_FeatureVersion, - Vulnerability, - Vulnerability_FixedIn_Feature, - Vulnerability_Affects_FeatureVersion, - Vulnerability_Notification, - KeyValue, - Lock - CASCADE;`, - }), - }) -} diff --git a/database/pgsql/migrations/00003_add_indexes.go b/database/pgsql/migrations/00003_add_indexes.go deleted file mode 100644 index 78ccaba2..00000000 --- a/database/pgsql/migrations/00003_add_indexes.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 3, - Up: migrate.Queries([]string{ - `CREATE UNIQUE INDEX namespace_name_key ON Namespace (name);`, - `CREATE INDEX vulnerability_name_idx ON Vulnerability (name);`, - `CREATE INDEX vulnerability_namespace_id_name_idx ON Vulnerability (namespace_id, name);`, - `CREATE UNIQUE INDEX featureversion_feature_id_version_key ON FeatureVersion (feature_id, version);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX namespace_name_key;`, - `DROP INDEX vulnerability_name_idx;`, - `DROP INDEX vulnerability_namespace_id_name_idx;`, - `DROP INDEX featureversion_feature_id_version_key;`, - }), - }) -} diff --git a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go b/database/pgsql/migrations/00004_add_index_notification_deleted_at.go deleted file mode 100644 index 12f38ab2..00000000 --- a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 4, - Up: migrate.Queries([]string{ - `CREATE INDEX vulnerability_notification_deleted_at_idx ON Vulnerability_Notification (deleted_at);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX vulnerability_notification_deleted_at_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00005_ldfv_index.go b/database/pgsql/migrations/00005_ldfv_index.go deleted file mode 100644 index ec8e7137..00000000 --- a/database/pgsql/migrations/00005_ldfv_index.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 5, - Up: migrate.Queries([]string{ - `CREATE INDEX layer_diff_featureversion_layer_id_modification_idx ON Layer_diff_FeatureVersion (layer_id, modification);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX layer_diff_featureversion_layer_id_modification_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00006_add_version_format.go b/database/pgsql/migrations/00006_add_version_format.go deleted file mode 100644 index 3a08f6f0..00000000 --- a/database/pgsql/migrations/00006_add_version_format.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 6, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ADD COLUMN version_format varchar(128);`, - `UPDATE Namespace SET version_format = 'rpm' WHERE name LIKE 'rhel%' OR name LIKE 'centos%' OR name LIKE 'fedora%' OR name LIKE 'amzn%' OR name LIKE 'scientific%' OR name LIKE 'ol%' OR name LIKE 'oracle%';`, - `UPDATE Namespace SET version_format = 'dpkg' WHERE version_format is NULL;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace DROP COLUMN version_format;`, - }), - }) -} diff --git a/database/pgsql/migrations/00008_add_multiplens.go b/database/pgsql/migrations/00008_add_multiplens.go deleted file mode 100644 index ecfb4762..00000000 --- a/database/pgsql/migrations/00008_add_multiplens.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 8, - Up: migrate.Queries([]string{ - // set on deletion, remove the corresponding rows in database - `CREATE TABLE IF NOT EXISTS Layer_Namespace( - id SERIAL PRIMARY KEY, - layer_id INT REFERENCES Layer(id) ON DELETE CASCADE, - namespace_id INT REFERENCES Namespace(id) ON DELETE CASCADE, - unique(layer_id, namespace_id) - );`, - `CREATE INDEX ON Layer_Namespace (namespace_id);`, - `CREATE INDEX ON Layer_Namespace (layer_id);`, - // move the namespace_id to the table - `INSERT INTO Layer_Namespace (layer_id, namespace_id) SELECT id, namespace_id FROM Layer;`, - // alter the Layer table to remove the column - `ALTER TABLE IF EXISTS Layer DROP namespace_id;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE IF EXISTS Layer ADD namespace_id INT NULL REFERENCES Namespace;`, - `CREATE INDEX ON Layer (namespace_id);`, - `UPDATE IF EXISTS Layer SET namespace_id = (SELECT lns.namespace_id FROM Layer_Namespace lns WHERE Layer.id = lns.layer_id LIMIT 1);`, - `DROP TABLE IF EXISTS Layer_Namespace;`, - }), - }) -} diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index 8d4b304b..1a78f837 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -15,61 +15,82 @@ package pgsql import ( - "time" + "database/sql" + "errors" + "sort" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { - if namespace.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Namespace") +var ( + errNamespaceNotFound = errors.New("Requested Namespace is not in database") +) + +// PersistNamespaces soi namespaces into database. +func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { + if len(namespaces) == 0 { + return nil } - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("namespace").Inc() - if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found { - promCacheHitsTotal.WithLabelValues("namespace").Inc() - return id.(int), nil + // Sorting is needed before inserting into database to prevent deadlock. + sort.Slice(namespaces, func(i, j int) bool { + return namespaces[i].Name < namespaces[j].Name && + namespaces[i].VersionFormat < namespaces[j].VersionFormat + }) + + keys := make([]interface{}, len(namespaces)*2) + for i, ns := range namespaces { + if ns.Name == "" || ns.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed") } + keys[i*2] = ns.Name + keys[i*2+1] = ns.VersionFormat } - // We do `defer observeQueryTime` here because we don't want to observe cached namespaces. - defer observeQueryTime("insertNamespace", "all", time.Now()) - - var id int - err := pgSQL.QueryRow(soiNamespace, namespace.Name, namespace.VersionFormat).Scan(&id) + _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...) if err != nil { - return 0, handleError("soiNamespace", err) + return handleError("queryPersistNamespace", err) } - - if pgSQL.cache != nil { - pgSQL.cache.Add("namespace:"+namespace.Name, id) - } - - return id, nil + return nil } -func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) { - rows, err := pgSQL.Query(listNamespace) - if err != nil { - return namespaces, handleError("listNamespace", err) +func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) { + if len(namespaces) == 0 { + return nil, nil } + + keys := make([]interface{}, len(namespaces)*2) + nsMap := map[database.Namespace]sql.NullInt64{} + for i, n := range namespaces { + keys[i*2] = n.Name + keys[i*2+1] = n.VersionFormat + nsMap[n] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...) + if err != nil { + return nil, handleError("searchNamespace", err) + } + defer rows.Close() + var ( + id sql.NullInt64 + ns database.Namespace + ) for rows.Next() { - var ns database.Namespace - - err = rows.Scan(&ns.ID, &ns.Name, &ns.VersionFormat) + err := rows.Scan(&id, &ns.Name, &ns.VersionFormat) if err != nil { - return namespaces, handleError("listNamespace.Scan()", err) + return nil, handleError("searchNamespace", err) } - - namespaces = append(namespaces, ns) - } - if err = rows.Err(); err != nil { - return namespaces, handleError("listNamespace.Rows()", err) + nsMap[ns] = id } - return namespaces, err + ids := make([]sql.NullInt64, len(namespaces)) + for i, ns := range namespaces { + ids[i] = nsMap[ns] + } + + return ids, nil } diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace_test.go index 0990b6f4..27ceefef 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace_test.go @@ -15,60 +15,69 @@ package pgsql import ( - "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" ) -func TestInsertNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestPersistNamespaces(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespaces", false) + defer closeTest(t, datastore, tx) - // Invalid Namespace. - id0, err := datastore.insertNamespace(database.Namespace{}) - assert.NotNil(t, err) - assert.Zero(t, id0) + ns1 := database.Namespace{} + ns2 := database.Namespace{Name: "t", VersionFormat: "b"} - // Insert Namespace and ensure we can find it. - id1, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - id2, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - assert.Equal(t, id1, id2) + // Empty Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{})) + // Invalid Case + assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1})) + // Duplicated Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2})) + // Existing Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2})) + + nsList := listNamespaces(t, tx) + assert.Len(t, nsList, 1) + assert.Equal(t, ns2, nsList[0]) } -func TestListNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("ListNamespaces", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - namespaces, err := datastore.ListNamespaces() - assert.Nil(t, err) - if assert.Len(t, namespaces, 2) { - for _, namespace := range namespaces { - switch namespace.Name { - case "debian:7", "debian:8": - continue - default: - assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name)) +func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Namespace]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + has[i] = true + } + for key, v := range has { + if !assert.True(t, v, key.Name+"is expected") { + return false } } + return true } + return false +} + +func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") + if err != nil { + t.FailNow() + } + defer rows.Close() + + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() + } + namespaces = append(namespaces, ns) + } + + return namespaces } diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index f8c6960d..ebc346d3 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -16,235 +16,320 @@ package pgsql import ( "database/sql" + "errors" "time" "github.com/guregu/null/zero" - "github.com/pborman/uuid" - log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -// do it in tx so we won't insert/update a vuln without notification and vice-versa. -// name and created doesn't matter. -func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { - defer observeQueryTime("createNotification", "all", time.Now()) +var ( + errNotificationNotFound = errors.New("requested notification is not found") +) - // Insert Notification. - oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} - newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} - _, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) +func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { + if len(notifications) == 0 { + return nil + } + + var ( + newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) + oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) + ) + + invalidCreationTime := time.Time{} + for _, noti := range notifications { + if noti.Name == "" { + return commonerr.NewBadRequestError("notification should not have empty name") + } + if noti.Created == invalidCreationTime { + return commonerr.NewBadRequestError("notification should not have empty created time") + } + + if noti.New != nil { + key := database.VulnerabilityID{ + Name: noti.New.Name, + Namespace: noti.New.Namespace.Name, + } + newVulnIDMap[key] = sql.NullInt64{} + } + + if noti.Old != nil { + key := database.VulnerabilityID{ + Name: noti.Old.Name, + Namespace: noti.Old.Namespace.Name, + } + oldVulnIDMap[key] = sql.NullInt64{} + } + } + + var ( + newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap)) + oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap)) + ) + + for vulnID := range newVulnIDMap { + newVulnIDs = append(newVulnIDs, vulnID) + } + + for vulnID := range oldVulnIDMap { + oldVulnIDs = append(oldVulnIDs, vulnID) + } + + ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs) if err != nil { - tx.Rollback() - return handleError("insertNotification", err) + return err + } + + for i, id := range ids { + if !id.Valid { + return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) + } + newVulnIDMap[newVulnIDs[i]] = id + } + + ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs) + if err != nil { + return err + } + + for i, id := range ids { + if !id.Valid { + return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) + } + oldVulnIDMap[oldVulnIDs[i]] = id + } + + var ( + newVulnID sql.NullInt64 + oldVulnID sql.NullInt64 + ) + + keys := make([]interface{}, len(notifications)*4) + for i, noti := range notifications { + if noti.New != nil { + newVulnID = newVulnIDMap[database.VulnerabilityID{ + Name: noti.New.Name, + Namespace: noti.New.Namespace.Name, + }] + } + + if noti.Old != nil { + oldVulnID = oldVulnIDMap[database.VulnerabilityID{ + Name: noti.Old.Name, + Namespace: noti.Old.Namespace.Name, + }] + } + + keys[4*i] = noti.Name + keys[4*i+1] = noti.Created + keys[4*i+2] = oldVulnID + keys[4*i+3] = newVulnID + } + + // NOTE(Sida): The data is not sorted before inserting into database under + // the fact that there's only one updater running at a time. If there are + // multiple updaters, deadlock may happen. + _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) + if err != nil { + return handleError("queryInsertNotifications", err) } return nil } -// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). -// Does not fill new/old vuln. -func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { - defer observeQueryTime("GetAvailableNotification", "all", time.Now()) - - before := time.Now().Add(-renotifyInterval) - row := pgSQL.QueryRow(searchNotificationAvailable, before) - notification, err := pgSQL.scanNotification(row, false) - - return notification, handleError("searchNotificationAvailable", err) -} - -func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { - defer observeQueryTime("GetNotification", "all", time.Now()) - - // Get Notification. - notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true) - if err != nil { - return notification, page, handleError("searchNotification", err) - } - - // Load vulnerabilities' LayersIntroducingVulnerability. - page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.OldVulnerability, - limit, - page.OldVulnerability, +func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) { + var ( + notification database.NotificationHook + created zero.Time + notified zero.Time + deleted zero.Time ) + err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(¬ification.Name, &created, ¬ified, &deleted) if err != nil { - return notification, page, err - } - - page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.NewVulnerability, - limit, - page.NewVulnerability, - ) - - if err != nil { - return notification, page, err - } - - return notification, page, nil -} - -func (pgSQL *pgSQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) { - var notification database.VulnerabilityNotification - var created zero.Time - var notified zero.Time - var deleted zero.Time - var oldVulnerabilityNullableID sql.NullInt64 - var newVulnerabilityNullableID sql.NullInt64 - - // Scan notification. - if hasVulns { - err := row.Scan( - ¬ification.ID, - ¬ification.Name, - &created, - ¬ified, - &deleted, - &oldVulnerabilityNullableID, - &newVulnerabilityNullableID, - ) - - if err != nil { - return notification, err - } - } else { - err := row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted) - - if err != nil { - return notification, err + if err == sql.ErrNoRows { + return notification, false, nil } + return notification, false, handleError("searchNotificationAvailable", err) } notification.Created = created.Time notification.Notified = notified.Time notification.Deleted = deleted.Time - if hasVulns { - if oldVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } - - notification.OldVulnerability = &vulnerability - } - - if newVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } - - notification.NewVulnerability = &vulnerability - } - } - - return notification, nil + return notification, true, nil } -// Fills Vulnerability.LayersIntroducingVulnerability. -// limit -1: won't do anything -// limit 0: will just get the startID of the second page -func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) { - tf := time.Now() - - if vulnerability == nil { - return -1, nil +func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { + vulnPage := database.PagedVulnerableAncestries{Limit: limit} + current := idPageNumber{0} + if currentPage != "" { + var err error + current, err = decryptPage(currentPage, tx.paginationKey) + if err != nil { + return vulnPage, err + } } - // A startID equals to -1 means that we reached the end already. - if startID == -1 || limit == -1 { - return -1, nil - } - - // Create a transaction to disable hash joins as our experience shows that - // PostgreSQL plans in certain cases a sequential scan and a hash on - // Layer_diff_FeatureVersion for the condition `ldfv.layer_id >= $2 AND - // ldfv.modification = 'add'` before realizing a hash inner join with - // Vulnerability_Affects_FeatureVersion. By disabling explictly hash joins, - // we force PostgreSQL to perform a bitmap index scan with - // `ldfv.featureversion_id = fv.id` on Layer_diff_FeatureVersion, followed by - // a bitmap heap scan on `ldfv.layer_id >= $2 AND ldfv.modification = 'add'`, - // thus avoiding a sequential scan on the biggest database table and - // allowing a small nested loop join instead. - tx, err := pgSQL.Begin() + err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( + &vulnPage.Name, + &vulnPage.Description, + &vulnPage.Link, + &vulnPage.Severity, + &vulnPage.Metadata, + &vulnPage.Namespace.Name, + &vulnPage.Namespace.VersionFormat, + ) if err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Begin()", err) - } - defer tx.Commit() - - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warning("searchNotificationLayerIntroducingVulnerability: could not disable hash join") + return vulnPage, handleError("searchVulnerabilityByID", err) } - // We do `defer observeQueryTime` here because we don't want to observe invalid calls. - defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) - - // Query with limit + 1, the last item will be used to know the next starting ID. - rows, err := tx.Query(searchNotificationLayerIntroducingVulnerability, - vulnerability.ID, startID, limit+1) + // the last result is used for the next page's startID + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) if err != nil { - return 0, handleError("searchNotificationLayerIntroducingVulnerability", err) + return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } defer rows.Close() - var layers []database.Layer + ancestries := []affectedAncestry{} for rows.Next() { - var layer database.Layer - - if err := rows.Scan(&layer.ID, &layer.Name); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) + var ancestry affectedAncestry + err := rows.Scan(&ancestry.id, &ancestry.name) + if err != nil { + return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } - - layers = append(layers, layer) - } - if err = rows.Err(); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err) + ancestries = append(ancestries, ancestry) } - size := limit - if len(layers) < limit { - size = len(layers) - } - vulnerability.LayersIntroducingVulnerability = layers[:size] + lastIndex := 0 + if len(ancestries)-1 < limit { + lastIndex = len(ancestries) + vulnPage.End = true + } else { + // Use the last ancestry's ID as the next PageNumber. + lastIndex = len(ancestries) - 1 + vulnPage.Next, err = encryptPage( + idPageNumber{ + ancestries[len(ancestries)-1].id, + }, tx.paginationKey) - nextID := -1 - if len(layers) > limit { - nextID = layers[limit].ID + if err != nil { + return vulnPage, err + } } - return nextID, nil + vulnPage.Affected = map[int]string{} + for _, ancestry := range ancestries[0:lastIndex] { + vulnPage.Affected[int(ancestry.id)] = ancestry.name + } + + vulnPage.Current, err = encryptPage(current, tx.paginationKey) + if err != nil { + return vulnPage, err + } + + return vulnPage, nil } -func (pgSQL *pgSQL) SetNotificationNotified(name string) error { - defer observeQueryTime("SetNotificationNotified", "all", time.Now()) +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( + database.VulnerabilityNotificationWithVulnerable, bool, error) { + var ( + noti database.VulnerabilityNotificationWithVulnerable + oldVulnID sql.NullInt64 + newVulnID sql.NullInt64 + created zero.Time + notified zero.Time + deleted zero.Time + ) - if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { + if name == "" { + return noti, false, commonerr.NewBadRequestError("Empty notification name is not allowed") + } + + noti.Name = name + + err := tx.QueryRow(searchNotification, name).Scan(&created, ¬ified, + &deleted, &oldVulnID, &newVulnID) + + if err != nil { + if err == sql.ErrNoRows { + return noti, false, nil + } + return noti, false, handleError("searchNotification", err) + } + + if created.Valid { + noti.Created = created.Time + } + + if notified.Valid { + noti.Notified = notified.Time + } + + if deleted.Valid { + noti.Deleted = deleted.Time + } + + if oldVulnID.Valid { + page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) + if err != nil { + return noti, false, err + } + noti.Old = &page + } + + if newVulnID.Valid { + page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) + if err != nil { + return noti, false, err + } + noti.New = &page + } + + return noti, true, nil +} + +func (tx *pgSession) MarkNotificationNotified(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } + + r, err := tx.Exec(updatedNotificationNotified, name) + if err != nil { return handleError("updatedNotificationNotified", err) } + + affected, err := r.RowsAffected() + if err != nil { + return handleError("updatedNotificationNotified", err) + } + + if affected <= 0 { + return handleError("updatedNotificationNotified", errNotificationNotFound) + } return nil } -func (pgSQL *pgSQL) DeleteNotification(name string) error { - defer observeQueryTime("DeleteNotification", "all", time.Now()) +func (tx *pgSession) DeleteNotification(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } - result, err := pgSQL.Exec(removeNotification, name) + result, err := tx.Exec(removeNotification, name) if err != nil { return handleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("removeNotification.RowsAffected()", err) + return handleError("removeNotification", err) } if affected <= 0 { - return commonerr.ErrNotFound + return handleError("removeNotification", commonerr.ErrNotFound) } return nil diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 24e79246..0d930d08 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -21,211 +21,225 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) -func TestNotification(t *testing.T) { - datastore, err := openDatabaseForTest("Notification", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestPagination(t *testing.T) { + datastore, tx := openSessionForTest(t, "Pagination", true) + defer closeTest(t, datastore, tx) - // Try to get a notification when there is none. - _, err = datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - // Create some data. - f1 := database.Feature{ - Name: "TestNotificationFeature1", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, - }, + ns := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", } - f2 := database.Feature{ - Name: "TestNotificationFeature2", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, - }, + vNew := database.Vulnerability{ + Namespace: ns, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, } - l1 := database.Layer{ - Name: "TestNotificationLayer1", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.1", - }, - }, + vOld := database.Vulnerability{ + Namespace: ns, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, } - l2 := database.Layer{ - Name: "TestNotificationLayer2", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.2", - }, - }, + noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "") + oldPage := database.PagedVulnerableAncestries{ + Vulnerability: vOld, + Limit: 1, + Affected: make(map[int]string), + End: true, } - l3 := database.Layer{ - Name: "TestNotificationLayer3", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.3", - }, - }, + newPage1 := database.PagedVulnerableAncestries{ + Vulnerability: vNew, + Limit: 1, + Affected: map[int]string{3: "ancestry-3"}, + End: false, } - l4 := database.Layer{ - Name: "TestNotificationLayer4", - Features: []database.FeatureVersion{ - { - Feature: f2, - Version: "0.1", - }, - }, + newPage2 := database.PagedVulnerableAncestries{ + Vulnerability: vNew, + Limit: 1, + Affected: map[int]string{4: "ancestry-4"}, + Next: "", + End: true, } - if !assert.Nil(t, datastore.InsertLayer(l1)) || - !assert.Nil(t, datastore.InsertLayer(l2)) || - !assert.Nil(t, datastore.InsertLayer(l3)) || - !assert.Nil(t, datastore.InsertLayer(l4)) { - return - } - - // Insert a new vulnerability that is introduced by three layers. - v1 := database.Vulnerability{ - Name: "TestNotificationVulnerability1", - Namespace: f1.Namespace, - Description: "TestNotificationDescription1", - Link: "TestNotificationLink1", - Severity: "Unknown", - FixedIn: []database.FeatureVersion{ - { - Feature: f1, - Version: "1.0", - }, - }, - } - assert.Nil(t, datastore.insertVulnerability(v1, false, true)) - - // Get the notification associated to the previously inserted vulnerability. - notification, err := datastore.GetAvailableNotification(time.Second) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - // Verify the renotify behaviour. - if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) { - _, err := datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - time.Sleep(50 * time.Millisecond) - notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond) - assert.Nil(t, err) - assert.Equal(t, notification.Name, notificationB.Name) - - datastore.SetNotificationNotified(notification.Name) - } - - // Get notification. - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2) - } - } - - // Get second page. - filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage) - if assert.Nil(t, err) { - assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - } - - // Delete notification. - assert.Nil(t, datastore.DeleteNotification(notification.Name)) - - _, err = datastore.GetAvailableNotification(time.Millisecond) - assert.Equal(t, commonerr.ErrNotFound, err) - } - - // Update a vulnerability and ensure that the old/new vulnerabilities are correct. - v1b := v1 - v1b.Severity = database.HighSeverity - v1b.FixedIn = []database.FeatureVersion{ - { - Feature: f1, - Version: versionfmt.MinVersion, - }, - { - Feature: f2, - Version: versionfmt.MaxVersion, - }, - } - - if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2) - } - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - - assert.Equal(t, -1, nextPage.NewVulnerability) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { + oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") } - assert.Nil(t, datastore.DeleteNotification(notification.Name)) + assert.Equal(t, int64(0), oldPageNum.StartID) + newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + assert.Equal(t, int64(0), newPageNum.StartID) + assert.Equal(t, int64(4), newPageNextNum.StartID) + + noti.Old.Current = "" + noti.New.Current = "" + noti.New.Next = "" + assert.Equal(t, oldPage, *noti.Old) + assert.Equal(t, newPage1, *noti.New) } } - // Delete a vulnerability and verify the notification. - if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) + page1, err := encryptPage(idPageNumber{0}, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.Nil(t, filledNotification.NewVulnerability) + page2, err := encryptPage(idPageNumber{4}, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1) - } + noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { + oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") } - assert.Nil(t, datastore.DeleteNotification(notification.Name)) + newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + + assert.Equal(t, int64(0), oldCurrentPage.StartID) + assert.Equal(t, int64(4), newCurrentPage.StartID) + noti.Old.Current = "" + noti.New.Current = "" + assert.Equal(t, oldPage, *noti.Old) + assert.Equal(t, newPage2, *noti.New) } } } + +func TestInsertVulnerabilityNotifications(t *testing.T) { + datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true) + + n1 := database.VulnerabilityNotification{} + n3 := database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: "random name", + Created: time.Now(), + }, + Old: nil, + New: &database.Vulnerability{}, + } + n4 := database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: "random name", + Created: time.Now(), + }, + Old: nil, + New: &database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + }, + } + + // invalid case + err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) + assert.NotNil(t, err) + + // invalid case: unknown vulnerability + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) + assert.NotNil(t, err) + + // invalid case: duplicated input notification + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) + assert.NotNil(t, err) + tx = restartSession(t, datastore, tx, false) + + // valid case + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.Nil(t, err) + // invalid case: notification is already in database + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.NotNil(t, err) + + closeTest(t, datastore, tx) +} + +func TestFindNewNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindNewNotification", true) + defer closeTest(t, datastore, tx) + + noti, ok, err := tx.FindNewNotification(time.Now()) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.Equal(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) + } + + // can't find the notified + assert.Nil(t, tx.MarkNotificationNotified("test")) + // if the notified time is before + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) + assert.Nil(t, err) + assert.False(t, ok) + // can find the notified after a period of time + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.NotEqual(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) + } + + assert.Nil(t, tx.DeleteNotification("test")) + // can't find in any time + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) + + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) +} + +func TestMarkNotificationNotified(t *testing.T) { + datastore, tx := openSessionForTest(t, "MarkNotificationNotified", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.MarkNotificationNotified("non-existing")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) +} + +func TestDeleteNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "DeleteNotification", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.DeleteNotification("non-existing")) + // valid case + assert.Nil(t, tx.DeleteNotification("test")) + // invalid case: notification is already deleted + assert.NotNil(t, tx.DeleteNotification("test")) +} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 34504a9a..8815fabb 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -31,6 +31,7 @@ import ( "github.com/remind101/migrate" log "github.com/sirupsen/logrus" + "github.com/coreos/clair/api/token" "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/migrations" "github.com/coreos/clair/pkg/commonerr" @@ -59,7 +60,7 @@ var ( promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "clair_pgsql_concurrent_lock_vafv_total", - Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", + Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", }) ) @@ -73,17 +74,65 @@ func init() { database.Register("pgsql", openDatabase) } -type Queryer interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row +// pgSessionCache is the session's cache, which holds the pgSQL's cache and the +// individual session's cache. Only when session.Commit is called, all the +// changes to pgSQL cache will be applied. +type pgSessionCache struct { + c *lru.ARCCache } type pgSQL struct { *sql.DB + cache *lru.ARCCache config Config } +type pgSession struct { + *sql.Tx + + paginationKey string +} + +type idPageNumber struct { + // StartID is an implementation detail for paginating by an ID required to + // be unique to every ancestry and always increasing. + // + // StartID is used to search for ancestry with ID >= StartID + StartID int64 +} + +func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) { + resultBytes, err := token.Marshal(page, paginationKey) + if err != nil { + return result, err + } + result = database.PageNumber(resultBytes) + return result, nil +} + +func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) { + err = token.Unmarshal(string(page), paginationKey, &result) + return +} + +// Begin initiates a transaction to database. The expected transaction isolation +// level in this implementation is "Read Committed". +func (pgSQL *pgSQL) Begin() (database.Session, error) { + tx, err := pgSQL.DB.Begin() + if err != nil { + return nil, err + } + return &pgSession{ + Tx: tx, + paginationKey: pgSQL.config.PaginationKey, + }, nil +} + +func (tx *pgSession) Commit() error { + return tx.Tx.Commit() +} + // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // the configuration. func (pgSQL *pgSQL) Close() { @@ -109,6 +158,7 @@ type Config struct { ManageDatabaseLifecycle bool FixturePath string + PaginationKey string } // openDatabase opens a PostgresSQL-backed Datastore using the given @@ -134,6 +184,10 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) } + if pg.config.PaginationKey == "" { + panic("pagination key should be given") + } + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) if err != nil { return nil, err @@ -179,7 +233,7 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig _, err = pg.DB.Exec(string(d)) if err != nil { pg.Close() - return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err) + return nil, fmt.Errorf("pgsql: an error occurred while importing fixtures: %v", err) } } @@ -217,7 +271,7 @@ func migrateDatabase(db *sql.DB) error { err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) if err != nil { - return fmt.Errorf("pgsql: an error occured while running migrations: %v", err) + return fmt.Errorf("pgsql: an error occurred while running migrations: %v", err) } log.Info("database migration ran successfully") @@ -271,7 +325,8 @@ func dropDatabase(source, dbName string) error { } // handleError logs an error with an extra description and masks the error if it's an SQL one. -// This ensures we never return plain SQL errors and leak anything. +// The function ensures we never return plain SQL errors and leak anything. +// The function should be used for every database query error. func handleError(desc string, err error) error { if err == nil { return nil @@ -297,6 +352,11 @@ func isErrUniqueViolation(err error) bool { return ok && pqErr.Code == "23505" } +// observeQueryTime computes the time elapsed since `start` to represent the +// query time. +// 1. `query` is a pgSession function name. +// 2. `subquery` is a specific query or a batched query. +// 3. `start` is the time right before query is executed. func observeQueryTime(query, subquery string, start time.Time) { promQueryDurationMilliseconds. WithLabelValues(query, subquery). diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 93f53144..96241666 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -15,27 +15,193 @@ package pgsql import ( + "database/sql" "fmt" + "io/ioutil" "os" "path/filepath" "runtime" "strings" + "testing" + fernet "github.com/fernet/fernet-go" "github.com/pborman/uuid" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + yaml "gopkg.in/yaml.v2" "github.com/coreos/clair/database" ) -func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { - ds, err := openDatabase(generateTestConfig(testName, loadFixture)) +var ( + withFixtureName, withoutFixtureName string +) + +func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) { + config := generateTestConfig(name, loadFixture, false) + source := config.Options["source"].(string) + name, url, err := parseConnectionString(source) + if err != nil { + panic(err) + } + + fixturePath := config.Options["fixturepath"].(string) + + if err := createDatabase(url, name); err != nil { + panic(err) + } + + // migration and fixture + db, err := sql.Open("postgres", source) + if err != nil { + panic(err) + } + + // Verify database state. + if err := db.Ping(); err != nil { + panic(err) + } + + // Run migrations. + if err := migrateDatabase(db); err != nil { + panic(err) + } + + if loadFixture { + log.Info("pgsql: loading fixtures") + + d, err := ioutil.ReadFile(fixturePath) + if err != nil { + panic(err) + } + + _, err = db.Exec(string(d)) + if err != nil { + panic(err) + } + } + + db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name) + db.Close() + + log.Info("Generated Template database ", name) + return url, name +} + +func dropTemplateDatabase(url string, name string) { + db, err := sql.Open("postgres", url) + if err != nil { + panic(err) + } + + if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil { + panic(err) + } + + if err := db.Close(); err != nil { + panic(err) + } + + if err := dropDatabase(url, name); err != nil { + panic(err) + } + +} +func TestMain(m *testing.M) { + fURL, fName := genTemplateDatabase("fixture", true) + nfURL, nfName := genTemplateDatabase("nonfixture", false) + + withFixtureName = fName + withoutFixtureName = nfName + + m.Run() + + dropTemplateDatabase(fURL, fName) + dropTemplateDatabase(nfURL, nfName) +} + +func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) { + var fixtureName string + if fixture { + fixtureName = withFixtureName + } else { + fixtureName = withoutFixtureName + } + + // copy the database into new database + var pg pgSQL + // Parse configuration. + pg.config = Config{ + CacheSize: 16384, + } + + bytes, err := yaml.Marshal(testConfig.Options) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + err = yaml.Unmarshal(bytes, &pg.config) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) if err != nil { return nil, err } - datastore := ds.(*pgSQL) + + // Create database. + if pg.config.ManageDatabaseLifecycle { + if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil { + return nil, err + } + } + + // Open database. + pg.DB, err = sql.Open("postgres", pg.config.Source) + fmt.Println("database", pg.config.Source) + if err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: could not open database: %v", err) + } + + return &pg, nil +} + +// copyDatabase creates a new database with +func copyDatabase(url, name string, templateName string) error { + // Open database. + db, err := sql.Open("postgres", url) + if err != nil { + return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) + } + defer db.Close() + + // Create database with copy + _, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName) + if err != nil { + return fmt.Errorf("pgsql: could not create database: %v", err) + } + + return nil +} + +func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { + var ( + db database.Datastore + err error + testConfig = generateTestConfig(testName, loadFixture, true) + ) + + db, err = openCopiedDatabase(testConfig, loadFixture) + + if err != nil { + return nil, err + } + datastore := db.(*pgSQL) return datastore, nil } -func generateTestConfig(testName string, loadFixture bool) database.RegistrableComponentConfig { +func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig { dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) var fixturePath string @@ -49,12 +215,60 @@ func generateTestConfig(testName string, loadFixture bool) database.RegistrableC source = fmt.Sprintf(sourceEnv, dbName) } + var key fernet.Key + if err := key.Generate(); err != nil { + panic("failed to generate pagination key" + err.Error()) + } + return database.RegistrableComponentConfig{ Options: map[string]interface{}{ "source": source, "cachesize": 0, - "managedatabaselifecycle": true, + "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, + "paginationkey": key.Encode(), }, } } + +func closeTest(t *testing.T, store database.Datastore, session database.Session) { + err := session.Rollback() + if err != nil { + t.Error(err) + t.FailNow() + } + + store.Close() +} + +func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) { + store, err := openDatabaseForTest(name, loadFixture) + if err != nil { + t.Error(err) + t.FailNow() + } + tx, err := store.Begin() + if err != nil { + t.Error(err) + t.FailNow() + } + return store, tx.(*pgSession) +} + +func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession { + var err error + if !commit { + err = tx.Rollback() + } else { + err = tx.Commit() + } + + if assert.Nil(t, err) { + session, err := datastore.Begin() + if assert.Nil(t, err) { + return session.(*pgSession) + } + } + t.FailNow() + return nil +} diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index 3fedf8d0..c7bd689b 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -14,185 +14,159 @@ package pgsql -import "strconv" +import ( + "fmt" + "strings" + + "github.com/lib/pq" +) const ( - lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` - disableHashJoin = `SET LOCAL enable_hashjoin = off` - disableMergeJoin = `SET LOCAL enable_mergejoin = off` + lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` // keyvalue.go - updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2` - insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1` + upsertKeyValue = ` + INSERT INTO KeyValue(key, value) + VALUES ($1, $2) + ON CONFLICT ON CONSTRAINT keyvalue_key_key + DO UPDATE SET key=$1, value=$2` // namespace.go - soiNamespace = ` - WITH new_namespace AS ( - INSERT INTO Namespace(name, version_format) - SELECT CAST($1 AS VARCHAR), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1) - RETURNING id - ) - SELECT id FROM Namespace WHERE name = $1 - UNION - SELECT id FROM new_namespace` - searchNamespace = `SELECT id FROM Namespace WHERE name = $1` - listNamespace = `SELECT id, name, version_format FROM Namespace` + searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` // feature.go - soiFeature = ` - WITH new_feature AS ( - INSERT INTO Feature(name, namespace_id) - SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) - WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) + soiNamespacedFeature = ` + WITH new_feature_ns AS ( + INSERT INTO namespaced_feature(feature_id, namespace_id) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER) + WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2) RETURNING id ) - SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 + SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2 UNION - SELECT id FROM new_feature` + SELECT id FROM new_feature_ns` - searchFeatureVersion = ` - SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2` + searchPotentialAffectingVulneraibilities = ` + SELECT nf.id, v.id, vaf.affected_version, vaf.id + FROM vulnerability_affected_feature AS vaf, vulnerability AS v, + namespaced_feature AS nf, feature AS f + WHERE nf.id = ANY($1) + AND nf.feature_id = f.id + AND nf.namespace_id = v.namespace_id + AND vaf.feature_name = f.name + AND vaf.vulnerability_id = v.id + AND v.deleted_at IS NULL` - soiFeatureVersion = ` - WITH new_featureversion AS ( - INSERT INTO FeatureVersion(feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) - RETURNING id - ) - SELECT false, id FROM FeatureVersion WHERE feature_id = $1 AND version = $2 - UNION - SELECT true, id FROM new_featureversion` - - searchVulnerabilityFixedInFeature = ` - SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature - WHERE feature_id = $1` - - insertVulnerabilityAffectsFeatureVersion = ` - INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) - VALUES($1, $2, $3)` + searchNamespacedFeaturesVulnerabilities = ` + SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, + v.severity, v.metadata, vaf.fixedin, n.name, n.version_format + FROM vulnerability_affected_namespaced_feature AS vanf, + Vulnerability AS v, + vulnerability_affected_feature AS vaf, + namespace AS n + WHERE vanf.namespaced_feature_id = ANY($1) + AND vaf.id = vanf.added_by + AND v.id = vanf.vulnerability_id + AND n.id = v.namespace_id + AND v.deleted_at IS NULL` // layer.go - searchLayer = ` - SELECT l.id, l.name, l.engineversion, p.id, p.name - FROM Layer l - LEFT JOIN Layer p ON l.parent_id = p.id - WHERE l.name = $1;` + searchLayerIDs = `SELECT id, hash FROM layer WHERE hash = ANY($1);` - searchLayerNamespace = ` - SELECT n.id, n.name, n.version_format - FROM Namespace n - JOIN Layer_Namespace lns ON lns.namespace_id = n.id - WHERE lns.layer_id = $1` + searchLayerFeatures = ` + SELECT feature.Name, feature.Version, feature.version_format + FROM feature, layer_feature + WHERE layer_feature.layer_id = $1 + AND layer_feature.feature_id = feature.id` - searchLayerFeatureVersion = ` - WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS( - SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false - FROM Layer l - WHERE l.id = $1 - UNION ALL - SELECT l.id, l.name, l.parent_id, lt.depth + 1, path || l.id, l.id = ANY(path) - FROM Layer l, layer_tree lt - WHERE l.id = lt.parent_id - ) - SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, fn.version_format, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name - FROM Layer_diff_FeatureVersion ldf - JOIN ( - SELECT row_number() over (ORDER BY depth DESC), id, name FROM layer_tree - ) AS ltree (ordering, id, name) ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn - WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id - ORDER BY ltree.ordering` + searchLayerNamespaces = ` + SELECT namespace.Name, namespace.version_format + FROM namespace, layer_namespace + WHERE layer_namespace.layer_id = $1 + AND layer_namespace.namespace_id = namespace.id` - searchFeatureVersionVulnerability = ` - SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, - vn.name, vn.version_format, vfif.version - FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, - Namespace vn, Vulnerability_FixedIn_Feature vfif - WHERE vafv.featureversion_id = ANY($1::integer[]) - AND vfif.vulnerability_id = v.id - AND vafv.fixedin_id = vfif.id - AND v.namespace_id = vn.id - AND v.deleted_at IS NULL` - - insertLayer = ` - INSERT INTO Layer(name, engineversion, parent_id, created_at) - VALUES($1, $2, $3, CURRENT_TIMESTAMP) - RETURNING id` - - insertLayerNamespace = `INSERT INTO Layer_Namespace(layer_id, namespace_id) VALUES($1, $2)` - removeLayerNamespace = `DELETE FROM Layer_Namespace WHERE layer_id = $1` - - updateLayer = `UPDATE LAYER SET engineversion = $2 WHERE id = $1` - - removeLayerDiffFeatureVersion = ` - DELETE FROM Layer_diff_FeatureVersion - WHERE layer_id = $1` - - insertLayerDiffFeatureVersion = ` - INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) - SELECT $1, fv.id, $2 - FROM FeatureVersion fv - WHERE fv.id = ANY($3::integer[])` - - removeLayer = `DELETE FROM Layer WHERE name = $1` + searchLayer = `SELECT id FROM layer WHERE hash = $1` + searchLayerDetectors = `SELECT detector FROM layer_detector WHERE layer_id = $1` + searchLayerListers = `SELECT lister FROM layer_lister WHERE layer_id = $1` // lock.go - insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` + soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)` + searchLock = `SELECT owner, until FROM Lock WHERE name = $1` updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2` removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` // vulnerability.go - searchVulnerabilityBase = ` - SELECT v.id, v.name, n.id, n.name, n.version_format, v.description, v.link, v.severity, v.metadata - FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` - searchVulnerabilityForUpdate = ` FOR UPDATE OF v` - searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` - searchVulnerabilityByID = ` WHERE v.id = $1` - searchVulnerabilityByNamespace = ` WHERE n.name = $1 AND v.deleted_at IS NULL - AND v.id >= $2 - ORDER BY v.id - LIMIT $3` + searchVulnerability = ` + SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.name = $1 + AND n.name = $2 + AND v.deleted_at IS NULL + ` - searchVulnerabilityFixedIn = ` - SELECT vfif.version, f.id, f.Name - FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id - WHERE vfif.vulnerability_id = $1` + insertVulnerabilityAffected = ` + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin) + VALUES ($1, $2, $3, $4) + RETURNING ID + ` + + searchVulnerabilityAffected = ` + SELECT vulnerability_id, feature_name, affected_version, fixedin + FROM vulnerability_affected_feature + WHERE vulnerability_id = ANY($1) + ` + + searchVulnerabilityByID = ` + SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.id = $1` + + searchVulnerabilityPotentialAffected = ` + WITH req AS ( + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id + FROM vulnerability_affected_feature AS vaf, + vulnerability AS v, + namespace AS n + WHERE vaf.vulnerability_id = ANY($1) + AND v.id = vaf.vulnerability_id + AND n.id = v.namespace_id + ) + SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by + FROM feature AS f, namespaced_feature AS nf, req + WHERE f.name = req.name + AND nf.namespace_id = req.n_id + AND nf.feature_id = f.id` + + insertVulnerabilityAffectedNamespacedFeature = ` + INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) + VALUES ($1, $2, $3)` insertVulnerability = ` - INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) - VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) - RETURNING id` - - soiVulnerabilityFixedInFeature = ` - WITH new_fixedinfeature AS ( - INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS INTEGER), CAST($3 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2) - RETURNING id + WITH ns AS ( + SELECT id FROM namespace WHERE name = $6 AND version_format = $7 ) - SELECT false, id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2 - UNION - SELECT true, id FROM new_fixedinfeature` - - searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1` + INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) + VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP) + RETURNING id` removeVulnerability = ` UPDATE Vulnerability - SET deleted_at = CURRENT_TIMESTAMP - WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) - AND name = $2 - AND deleted_at IS NULL - RETURNING id` + SET deleted_at = CURRENT_TIMESTAMP + WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) + AND name = $2 + AND deleted_at IS NULL + RETURNING id` // notification.go insertNotification = ` INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) - VALUES($1, CURRENT_TIMESTAMP, $2, $3)` + VALUES ($1, $2, $3, $4)` updatedNotificationNotified = ` UPDATE Vulnerability_Notification @@ -202,10 +176,10 @@ const ( removeNotification = ` UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP - WHERE name = $1` + WHERE name = $1 AND deleted_at IS NULL` searchNotificationAvailable = ` - SELECT id, name, created_at, notified_at, deleted_at + SELECT name, created_at, notified_at, deleted_at FROM Vulnerability_Notification WHERE (notified_at IS NULL OR notified_at < $1) AND deleted_at IS NULL @@ -214,43 +188,231 @@ const ( LIMIT 1` searchNotification = ` - SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id + SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` - searchNotificationLayerIntroducingVulnerability = ` - WITH LDFV AS ( - SELECT DISTINCT ldfv.layer_id - FROM Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv - WHERE ldfv.layer_id >= $2 - AND vafv.vulnerability_id = $1 - AND vafv.featureversion_id = fv.id - AND ldfv.featureversion_id = fv.id - AND ldfv.modification = 'add' - ORDER BY ldfv.layer_id - ) - SELECT l.id, l.name - FROM LDFV, Layer l - WHERE LDFV.layer_id = l.id - LIMIT $3` + searchNotificationVulnerableAncestry = ` + SELECT DISTINCT ON (a.id) + a.id, a.name + FROM vulnerability_affected_namespaced_feature AS vanf, + ancestry AS a, ancestry_feature AS af + WHERE vanf.vulnerability_id = $1 + AND a.id >= $2 + AND a.id = af.ancestry_id + AND af.namespaced_feature_id = vanf.namespaced_feature_id + ORDER BY a.id ASC + LIMIT $3;` - // complex_test.go - searchComplexTestFeatureVersionAffects = ` - SELECT v.name - FROM FeatureVersion fv - LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id - JOIN Vulnerability v ON vaf.vulnerability_id = v.id - WHERE featureversion_id = $1` + // ancestry.go + persistAncestryLister = ` + INSERT INTO ancestry_lister (ancestry_id, lister) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT) + WHERE NOT EXISTS (SELECT id FROM ancestry_lister WHERE ancestry_id = $1 AND lister = $2) ON CONFLICT DO NOTHING` + + persistAncestryDetector = ` + INSERT INTO ancestry_detector (ancestry_id, detector) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT) + WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector = $2) ON CONFLICT DO NOTHING` + + insertAncestry = `INSERT INTO ancestry (name) VALUES ($1) RETURNING id` + + searchAncestryLayer = ` + SELECT layer.hash + FROM layer, ancestry_layer + WHERE ancestry_layer.ancestry_id = $1 + AND ancestry_layer.layer_id = layer.id + ORDER BY ancestry_layer.ancestry_index ASC` + + searchAncestryFeatures = ` + SELECT namespace.name, namespace.version_format, feature.name, feature.version + FROM namespace, feature, ancestry, namespaced_feature, ancestry_feature + WHERE ancestry.name = $1 + AND ancestry.id = ancestry_feature.ancestry_id + AND ancestry_feature.namespaced_feature_id = namespaced_feature.id + AND namespaced_feature.feature_id = feature.id + AND namespaced_feature.namespace_id = namespace.id` + + searchAncestry = `SELECT id FROM ancestry WHERE name = $1` + searchAncestryDetectors = `SELECT detector FROM ancestry_detector WHERE ancestry_id = $1` + searchAncestryListers = `SELECT lister FROM ancestry_lister WHERE ancestry_id = $1` + removeAncestry = `DELETE FROM ancestry WHERE name = $1` + insertAncestryLayer = `INSERT INTO ancestry_layer(ancestry_id, ancestry_index, layer_id) VALUES($1,$2,$3)` + insertAncestryFeature = `INSERT INTO ancestry_feature(ancestry_id, namespaced_feature_id) VALUES ($1, $2)` ) -// buildInputArray constructs a PostgreSQL input array from the specified integers. -// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using -// a single placeholder. -func buildInputArray(ints []int) string { - str := "{" - for i := 0; i < len(ints)-1; i++ { - str = str + strconv.Itoa(ints[i]) + "," - } - str = str + strconv.Itoa(ints[len(ints)-1]) + "}" - return str +// NOTE(Sida): Every search query can only have count less than postgres set +// stack depth. IN will be resolved to nested OR_s and the parser might exceed +// stack depth. TODO(Sida): Generate different queries for different count: if +// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for +// count > 65535, use is expected to split data into batches. +func querySearchLastDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT vid, vname, nname FROM ( + SELECT v.id AS vid, v.name AS vname, n.name AS nname, + row_number() OVER ( + PARTITION by (v.name, n.name) + ORDER BY v.deleted_at DESC + ) AS rownum + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND (v.name, n.name) IN ( %s ) + AND v.deleted_at IS NOT NULL + ) tmp WHERE rownum <= 1`, + queryString(2, count)) +} + +func querySearchNotDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) + AND v.deleted_at IS NULL`, + queryString(2, count)) +} + +func querySearchFeatureID(featureCount int) string { + return fmt.Sprintf(` + SELECT id, name, version, version_format + FROM Feature WHERE (name, version, version_format) IN (%s)`, + queryString(3, featureCount), + ) +} + +func querySearchNamespacedFeature(nsfCount int) string { + return fmt.Sprintf(` + SELECT nf.id, f.name, f.version, f.version_format, n.name + FROM namespaced_feature AS nf, feature AS f, namespace AS n + WHERE nf.feature_id = f.id + AND nf.namespace_id = n.id + AND n.version_format = f.version_format + AND (f.name, f.version, f.version_format, n.name) IN (%s)`, + queryString(4, nsfCount), + ) +} + +func querySearchNamespace(nsCount int) string { + return fmt.Sprintf( + `SELECT id, name, version_format + FROM namespace WHERE (name, version_format) IN (%s)`, + queryString(2, nsCount), + ) +} + +func queryInsert(count int, table string, columns ...string) string { + base := `INSERT INTO %s (%s) VALUES %s` + t := pq.QuoteIdentifier(table) + cols := make([]string, len(columns)) + for i, c := range columns { + cols[i] = pq.QuoteIdentifier(c) + } + colsQuoted := strings.Join(cols, ",") + return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count)) +} + +func queryPersist(count int, table, constraint string, columns ...string) string { + ct := "" + if constraint != "" { + ct = fmt.Sprintf("ON CONSTRAINT %s", constraint) + } + return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct) +} + +func queryInsertNotifications(count int) string { + return queryInsert(count, + "vulnerability_notification", + "name", + "created_at", + "old_vulnerability_id", + "new_vulnerability_id", + ) +} + +func queryPersistFeature(count int) string { + return queryPersist(count, + "feature", + "feature_name_version_version_format_key", + "name", + "version", + "version_format") +} + +func queryPersistLayerFeature(count int) string { + return queryPersist(count, + "layer_feature", + "layer_feature_layer_id_feature_id_key", + "layer_id", + "feature_id") +} + +func queryPersistNamespace(count int) string { + return queryPersist(count, + "namespace", + "namespace_name_version_format_key", + "name", + "version_format") +} + +func queryPersistLayerListers(count int) string { + return queryPersist(count, + "layer_lister", + "layer_lister_layer_id_lister_key", + "layer_id", + "lister") +} + +func queryPersistLayerDetectors(count int) string { + return queryPersist(count, + "layer_detector", + "layer_detector_layer_id_detector_key", + "layer_id", + "detector") +} + +func queryPersistLayerNamespace(count int) string { + return queryPersist(count, + "layer_namespace", + "layer_namespace_layer_id_namespace_id_key", + "layer_id", + "namespace_id") +} + +// size of key and array should be both greater than 0 +func queryString(keySize, arraySize int) string { + if arraySize <= 0 || keySize <= 0 { + panic("Bulk Query requires size of element tuple and number of elements to be greater than 0") + } + keys := make([]string, 0, arraySize) + for i := 0; i < arraySize; i++ { + key := make([]string, keySize) + for j := 0; j < keySize; j++ { + key[j] = fmt.Sprintf("$%d", i*keySize+j+1) + } + keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ","))) + } + return strings.Join(keys, ",") +} + +func queryPersistNamespacedFeature(count int) string { + return queryPersist(count, "namespaced_feature", + "namespaced_feature_namespace_id_feature_id_key", + "feature_id", + "namespace_id") +} + +func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string { + return queryPersist(count, "vulnerability_affected_namespaced_feature", + "vulnerability_affected_namesp_vulnerability_id_namespaced_f_key", + "vulnerability_id", + "namespaced_feature_id", + "added_by") +} + +func queryPersistLayer(count int) string { + return queryPersist(count, "layer", "", "hash") +} + +func queryInvalidateVulnerabilityCache(count int) string { + return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature + WHERE vulnerability_id IN (%s)`, + queryString(1, count)) } diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index b01e170e..a4ccd31c 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -1,73 +1,117 @@ --- Copyright 2015 clair authors --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - INSERT INTO namespace (id, name, version_format) VALUES - (1, 'debian:7', 'dpkg'), - (2, 'debian:8', 'dpkg'); +(1, 'debian:7', 'dpkg'), +(2, 'debian:8', 'dpkg'), +(3, 'fake:1.0', 'rpm'); -INSERT INTO feature (id, namespace_id, name) VALUES - (1, 1, 'wechat'), - (2, 1, 'openssl'), - (4, 1, 'libssl'), - (3, 2, 'openssl'); +INSERT INTO feature (id, name, version, version_format) VALUES +(1, 'wechat', '0.5', 'dpkg'), +(2, 'openssl', '1.0', 'dpkg'), +(3, 'openssl', '2.0', 'dpkg'), +(4, 'fake', '2.0', 'rpm'); -INSERT INTO featureversion (id, feature_id, version) VALUES - (1, 1, '0.5'), - (2, 2, '1.0'), - (3, 2, '2.0'), - (4, 3, '1.0'); +INSERT INTO layer (id, hash) VALUES + (1, 'layer-0'), -- blank + (2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0 + (3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0 + (4, 'layer-3a'),-- debian:7; + (5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0 + (6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake) -INSERT INTO layer (id, name, engineversion, parent_id) VALUES - (1, 'layer-0', 1, NULL), - (2, 'layer-1', 1, 1), - (3, 'layer-2', 1, 2), - (4, 'layer-3a', 1, 3), - (5, 'layer-3b', 1, 3); - -INSERT INTO layer_namespace (id, layer_id, namespace_id) VALUES +INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES (1, 2, 1), (2, 3, 1), (3, 4, 1), (4, 5, 2), - (5, 5, 1); + (5, 6, 1), + (6, 6, 3); -INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modification) VALUES - (1, 2, 1, 'add'), - (2, 2, 2, 'add'), - (3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0 - (4, 3, 3, 'add'), -- ^ - (5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0 - (6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0 +INSERT INTO layer_feature(id, layer_id, feature_id) VALUES + (1, 2, 1), + (2, 2, 2), + (3, 3, 1), + (4, 3, 3), + (5, 5, 1), + (6, 5, 2), + (7, 6, 4), + (8, 6, 3); + +INSERT INTO layer_lister(id, layer_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'), + (3, 3, 'dpkg'), + (4, 4, 'dpkg'), + (5, 5, 'dpkg'), + (6, 6, 'dpkg'), + (7, 6, 'rpm'); + +INSERT INTO layer_detector(id, layer_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'), + (3, 3, 'os-release'), + (4, 4, 'os-release'), + (5, 5, 'os-release'), + (6, 6, 'os-release'), + (7, 6, 'apt-sources'); + +INSERT INTO ancestry (id, name) VALUES + (1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a + (2, 'ancestry-2'), -- layer-0, layer-1, layer-2, layer-3b + (3, 'ancestry-3'), -- empty; just for testing the vulnerable ancestry + (4, 'ancestry-4'); -- empty; just for testing the vulnerable ancestry + +INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'); + +INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'); + +INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES + (1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3), + (5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3); + +INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES + (1, 1, 1), -- wechat 0.5, debian:7 + (2, 2, 1), -- openssl 1.0, debian:7 + (3, 2, 2), -- openssl 1.0, debian:8 + (4, 3, 1); -- openssl 2.0, debian:7 + +INSERT INTO ancestry_feature (id, ancestry_id, namespaced_feature_id) VALUES + (1, 1, 1), (2, 1, 4), + (3, 2, 1), (4, 2, 3), + (5, 3, 2), (6, 4, 2); -- assume that ancestry-3 and ancestry-4 are vulnerable. INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); -INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES - (1, 1, 2, '2.0'), - (2, 1, 4, '1.9-abc'); +INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES + (3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483'); + +INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES +(1, 1, 'openssl', '2.0', '2.0'), +(2, 1, 'libssl', '1.9-abc', '1.9-abc'); -INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES - (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 +INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES + (1, 1, 2, 1); + +INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES + (1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7' SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('featureversion', 'id'), (SELECT MAX(id) FROM featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('layer_diff_featureversion', 'id'), (SELECT MAX(id) FROM layer_diff_featureversion)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_fixedin_feature', 'id'), (SELECT MAX(id) FROM vulnerability_fixedin_feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affects_featureversion', 'id'), (SELECT MAX(id) FROM vulnerability_affects_featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1); diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index efb57392..ab92c0e9 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -17,352 +17,207 @@ package pgsql import ( "database/sql" "encoding/json" - "reflect" + "errors" "time" - "github.com/guregu/null/zero" + "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/pkg/commonerr" ) -// compareStringLists returns the strings that are present in X but not in Y. -func compareStringLists(X, Y []string) []string { - m := make(map[string]bool) +var ( + errVulnerabilityNotFound = errors.New("vulnerability is not in database") +) - for _, y := range Y { - m[y] = true - } - - diff := []string{} - for _, x := range X { - if m[x] { - continue - } - - diff = append(diff, x) - m[x] = true - } - - return diff +type affectedAncestry struct { + name string + id int64 } -func compareStringListsInBoth(X, Y []string) []string { - m := make(map[string]struct{}) - - for _, y := range Y { - m[y] = struct{}{} - } - - diff := []string{} - for _, x := range X { - if _, e := m[x]; e { - diff = append(diff, x) - delete(m, x) - } - } - - return diff +type affectRelation struct { + vulnerabilityID int64 + namespacedFeatureID int64 + addedBy int64 } -func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { - defer observeQueryTime("listVulnerabilities", "all", time.Now()) +type affectedFeatureRows struct { + rows map[int64]database.AffectedFeature +} - // Query Namespace. - var id int - err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&id) +func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + defer observeQueryTime("findVulnerabilities", "", time.Now()) + resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) + vulnIDMap := map[int64][]*database.NullableVulnerability{} + + //TODO(Sida): Change to bulk search. + stmt, err := tx.Prepare(searchVulnerability) if err != nil { - return nil, -1, handleError("searchNamespace", err) - } else if id == 0 { - return nil, -1, commonerr.ErrNotFound + return nil, err } - // Query. - query := searchVulnerabilityBase + searchVulnerabilityByNamespace - rows, err := pgSQL.Query(query, namespaceName, startID, limit+1) - if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace", err) - } - defer rows.Close() - - var vulns []database.Vulnerability - nextID := -1 - size := 0 - // Scan query. - for rows.Next() { - var vulnerability database.Vulnerability - - err := rows.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, + // load vulnerabilities + for i, key := range vulnerabilities { + var ( + id sql.NullInt64 + vuln = database.NullableVulnerability{ + VulnerabilityWithAffected: database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: key.Name, + Namespace: database.Namespace{ + Name: key.Namespace, + }, + }, + }, + } ) + + err := stmt.QueryRow(key.Name, key.Namespace).Scan( + &id, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.Namespace.VersionFormat, + ) + + if err != nil && err != sql.ErrNoRows { + stmt.Close() + return nil, handleError("searchVulnerability", err) + } + vuln.Valid = id.Valid + resultVuln[i] = vuln + if id.Valid { + vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i]) + } + } + + if err := stmt.Close(); err != nil { + return nil, handleError("searchVulnerability", err) + } + + toQuery := make([]int64, 0, len(vulnIDMap)) + for id := range vulnIDMap { + toQuery = append(toQuery, id) + } + + // load vulnerability affected features + rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) + if err != nil { + return nil, handleError("searchVulnerabilityAffected", err) + } + + for rows.Next() { + var ( + id int64 + f database.AffectedFeature + ) + + err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err) + return nil, handleError("searchVulnerabilityAffected", err) } - size++ - if size > limit { - nextID = vulnerability.ID - } else { - vulns = append(vulns, vulnerability) + + for _, vuln := range vulnIDMap[id] { + f.Namespace = vuln.Namespace + vuln.Affected = append(vuln.Affected, f) } } - if err := rows.Err(); err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err) + return resultVuln, nil +} + +func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { + defer observeQueryTime("insertVulnerabilities", "all", time.Now()) + // bulk insert vulnerabilities + vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) + if err != nil { + return err } - return vulns, nextID, nil -} - -func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { - return findVulnerability(pgSQL, namespaceName, name, false) -} - -func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerability", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" - query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName - if forUpdate { - queryName = queryName + "+searchVulnerabilityForUpdate" - query = query + searchVulnerabilityForUpdate + // bulk insert vulnerability affected features + vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) + if err != nil { + return err } - return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name)) + return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) } -func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByID" - query := searchVulnerabilityBase + searchVulnerabilityByID - - return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id)) -} - -func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) { - var vulnerability database.Vulnerability - - err := vulnerabilityRow.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, +// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. +// +// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. +func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { + var ( + vulnFeature = map[int64]affectedFeatureRows{} + affectedID int64 ) + //TODO(Sida): Change to bulk insert. + stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { - return vulnerability, handleError(queryName+".Scan()", err) + return nil, handleError("insertVulnerabilityAffected", err) } - if vulnerability.ID == 0 { - return vulnerability, commonerr.ErrNotFound - } - - // Query the FixedIn FeatureVersion now. - rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) - if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) - } - defer rows.Close() - - for rows.Next() { - var featureVersionID zero.Int - var featureVersionVersion zero.String - var featureVersionFeatureName zero.String - - err := rows.Scan( - &featureVersionVersion, - &featureVersionID, - &featureVersionFeatureName, - ) - - if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) - } - - if !featureVersionID.IsZero() { - // Note that the ID we fill in featureVersion is actually a Feature ID, and not - // a FeatureVersion ID. - featureVersion := database.FeatureVersion{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Feature: database.Feature{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Namespace: vulnerability.Namespace, - Name: featureVersionFeatureName.String, - }, - Version: featureVersionVersion.String, + defer stmt.Close() + for i, vuln := range vulnerabilities { + // affected feature row ID -> affected feature + affectedFeatures := map[int64]database.AffectedFeature{} + for _, f := range vuln.Affected { + err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID) + if err != nil { + return nil, handleError("insertVulnerabilityAffected", err) } - vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + affectedFeatures[affectedID] = f } + vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures} } - if err := rows.Err(); err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) - } - - return vulnerability, nil + return vulnFeature, nil } -// FixedIn.Namespace are not necessary, they are overwritten by the vuln. -// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore. -func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error { - for _, vulnerability := range vulnerabilities { - err := pgSQL.insertVulnerability(vulnerability, false, generateNotifications) +// insertVulnerabilities inserts a set of unique vulnerabilities into database, +// under the assumption that all vulnerabilities are valid. +func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { + var ( + vulnID int64 + vulnIDs = make([]int64, 0, len(vulnerabilities)) + vulnMap = map[database.VulnerabilityID]struct{}{} + ) + + for _, v := range vulnerabilities { + key := database.VulnerabilityID{ + Name: v.Name, + Namespace: v.Namespace.Name, + } + + // Ensure uniqueness of vulnerability IDs + if _, ok := vulnMap[key]; ok { + return nil, errors.New("inserting duplicated vulnerabilities is not allowed") + } + vulnMap[key] = struct{}{} + } + + //TODO(Sida): Change to bulk insert. + stmt, err := tx.Prepare(insertVulnerability) + if err != nil { + return nil, handleError("insertVulnerability", err) + } + + defer stmt.Close() + for _, vuln := range vulnerabilities { + err := stmt.QueryRow(vuln.Name, vuln.Description, + vuln.Link, &vuln.Severity, &vuln.Metadata, + vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { - return err - } - } - return nil -} - -func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error { - tf := time.Now() - - // Verify parameters - if vulnerability.Name == "" || vulnerability.Namespace.Name == "" { - return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace") - } - - for i := 0; i < len(vulnerability.FixedIn); i++ { - fifv := &vulnerability.FixedIn[i] - - if fifv.Feature.Namespace.Name == "" { - // As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's - // Namespace. - fifv.Feature.Namespace = vulnerability.Namespace - } else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name { - msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability" - log.Warning(msg) - return commonerr.NewBadRequestError(msg) - } - } - - // We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities. - defer observeQueryTime("insertVulnerability", "all", tf) - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Begin()", err) - } - - // Find existing vulnerability and its Vulnerability_FixedIn_Features (for update). - existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true) - if err != nil && err != commonerr.ErrNotFound { - tx.Rollback() - return err - } - - if onlyFixedIn { - // Because this call tries to update FixedIn FeatureVersion, import all other data from the - // existing one. - if existingVulnerability.ID == 0 { - return commonerr.ErrNotFound + return nil, handleError("insertVulnerability", err) } - fixedIn := vulnerability.FixedIn - vulnerability = existingVulnerability - vulnerability.FixedIn = fixedIn + vulnIDs = append(vulnIDs, vulnID) } - if existingVulnerability.ID != 0 { - updateMetadata := vulnerability.Description != existingVulnerability.Description || - vulnerability.Link != existingVulnerability.Link || - vulnerability.Severity != existingVulnerability.Severity || - !reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) - - // Construct the entire list of FixedIn FeatureVersion, by using the - // the FixedIn list of the old vulnerability. - // - // TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the - // existing vulnerability in order to make metadata updates much faster. - var updateFixedIn bool - vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn) - - if !updateMetadata && !updateFixedIn { - tx.Commit() - return nil - } - - // Mark the old vulnerability as non latest. - _, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name) - if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) - } - } else { - // The vulnerability is new, we don't want to have any - // versionfmt.MinVersion as they are only used for diffing existing - // vulnerabilities. - var fixedIn []database.FeatureVersion - for _, fv := range vulnerability.FixedIn { - if fv.Version != versionfmt.MinVersion { - fixedIn = append(fixedIn, fv) - } - } - vulnerability.FixedIn = fixedIn - } - - // Find or insert Vulnerability's Namespace. - namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) - if err != nil { - return err - } - - // Insert vulnerability. - err = tx.QueryRow( - insertVulnerability, - namespaceID, - vulnerability.Name, - vulnerability.Description, - vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - ).Scan(&vulnerability.ID) - - if err != nil { - tx.Rollback() - return handleError("insertVulnerability", err) - } - - // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. - err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn) - if err != nil { - tx.Rollback() - return err - } - - // Create a notification. - if generateNotification { - err = createNotification(tx, existingVulnerability.ID, vulnerability.ID) - if err != nil { - return err - } - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Commit()", err) - } - - return nil + return vulnIDs, nil } // castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that @@ -376,241 +231,208 @@ func castMetadata(m database.MetadataMap) database.MetadataMap { return c } -// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result. -func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) { - currentMap, currentNames := createFeatureVersionNameMap(currentList) - diffMap, diffNames := createFeatureVersionNameMap(diff) - - addedNames := compareStringLists(diffNames, currentNames) - inBothNames := compareStringListsInBoth(diffNames, currentNames) - - different := false - - for _, name := range addedNames { - if diffMap[name].Version == versionfmt.MinVersion { - // MinVersion only makes sense when a Feature is already fixed in some version, - // in which case we would be in the "inBothNames". - continue - } - - currentMap[name] = diffMap[name] - different = true - } - - for _, name := range inBothNames { - fv := diffMap[name] - - if fv.Version == versionfmt.MinVersion { - // MinVersion means that the Feature doesn't affect the Vulnerability anymore. - delete(currentMap, name) - different = true - } else if fv.Version != currentMap[name].Version { - // The version got updated. - currentMap[name] = diffMap[name] - different = true - } - } - - // Convert currentMap to a slice and return it. - var newList []database.FeatureVersion - for _, fv := range currentMap { - newList = append(newList, fv) - } - - return newList, different -} - -func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) { - m := make(map[string]database.FeatureVersion, 0) - s := make([]string, 0, len(features)) - - for i := 0; i < len(features); i++ { - featureVersion := features[i] - m[featureVersion.Feature.Name] = featureVersion - s = append(s, featureVersion.Feature.Name) - } - - return m, s -} - -// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given -// vulnerability with the specified database.FeatureVersion list and uses -// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to -// Vulnerability_Affects_FeatureVersion. -func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error { - defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) - - // Insert or find the Features. - // TODO(Quentin-M): Batch me. - var err error - var features []*database.Feature - for i := 0; i < len(fixedIn); i++ { - features = append(features, &fixedIn[i].Feature) - } - for _, feature := range features { - if feature.ID == 0 { - if feature.ID, err = pgSQL.insertFeature(*feature); err != nil { - return err - } - } - } - - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertFeatureVersion to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t := time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertVulnerability", "lock", t) - +func (tx *pgSession) lockFeatureVulnerabilityCache() error { + _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { - tx.Rollback() - return handleError("insertVulnerability.lockVulnerabilityAffects", err) + return handleError("lockVulnerabilityAffects", err) } - - for _, fv := range fixedIn { - var fixedInID int - var created bool - - // Find or create entry in Vulnerability_FixedIn_Feature. - err = tx.QueryRow( - soiVulnerabilityFixedInFeature, - vulnerabilityID, fv.Feature.ID, - &fv.Version, - ).Scan(&created, &fixedInID) - - if err != nil { - return handleError("insertVulnerabilityFixedInFeature", err) - } - - if !created { - // The relationship between the feature and the vulnerability already - // existed, no need to update Vulnerability_Affects_FeatureVersion. - continue - } - - // Insert Vulnerability_Affects_FeatureVersion. - err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { - return err - } - } - return nil } -func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error { - // Find every FeatureVersions of the Feature that the vulnerability affects. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchFeatureVersionByFeature, featureID) - if err != nil { - return handleError("searchFeatureVersionByFeature", err) - } - defer rows.Close() - - var affecteds []database.FeatureVersion - for rows.Next() { - var affected database.FeatureVersion - - err := rows.Scan(&affected.ID, &affected.Version) - if err != nil { - return handleError("searchFeatureVersionByFeature.Scan()", err) - } - - cmp, err := versionfmt.Compare(versionFormat, affected.Version, fixedInVersion) - if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion is lower than the fixed version of this vulnerability, - // thus, this FeatureVersion is affected by it. - affecteds = append(affecteds, affected) - } - } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionByFeature.Rows()", err) - } - rows.Close() - - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affected := range affecteds { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) - } - } - - return nil -} - -func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { - defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now()) - - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: fixes, - } - - return pgSQL.insertVulnerability(v, true, true) -} - -func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now()) - - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{ - Name: featureName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - }, - Version: versionfmt.MinVersion, - }, - }, - } - - return pgSQL.insertVulnerability(v, true, true) -} - -func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Begin()", err) - } - - var vulnerabilityID int - err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID) - if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) - } - - // Create a notification. - err = createNotification(tx, vulnerabilityID, 0) +// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID +// to affected feature rows and caches them. +func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { + // Prevent InsertNamespacedFeatures to modify it. + err := tx.lockFeatureVulnerabilityCache() if err != nil { return err } - // Commit transaction. - err = tx.Commit() + vulnIDs := []int64{} + for id := range affected { + vulnIDs = append(vulnIDs, id) + } + + rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Commit()", err) + return handleError("searchVulnerabilityPotentialAffected", err) + } + + defer rows.Close() + + relation := []affectRelation{} + for rows.Next() { + var ( + vulnID int64 + nsfID int64 + fVersion string + addedBy int64 + ) + + err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) + if err != nil { + return handleError("searchVulnerabilityPotentialAffected", err) + } + + candidate, ok := affected[vulnID].rows[addedBy] + + if !ok { + return errors.New("vulnerability affected feature not found") + } + + if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, + fVersion, + candidate.AffectedVersion); err == nil { + if in { + relation = append(relation, + affectRelation{ + vulnerabilityID: vulnID, + namespacedFeatureID: nsfID, + addedBy: addedBy, + }) + } + } else { + return err + } + } + + //TODO(Sida): Change to bulk insert. + for _, r := range relation { + result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) + if err != nil { + return handleError("insertVulnerabilityAffectedNamespacedFeature", err) + } + + if num, err := result.RowsAffected(); err == nil { + if num <= 0 { + return errors.New("Nothing cached in database") + } + } else { + return err + } + } + + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation)) + return nil +} + +func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { + defer observeQueryTime("DeleteVulnerability", "all", time.Now()) + + vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) + if err != nil { + return err + } + + if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { + return err + } + return nil +} + +func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { + if len(vulnerabilityIDs) == 0 { + return nil + } + + // Prevent InsertNamespacedFeatures to modify it. + err := tx.lockFeatureVulnerabilityCache() + if err != nil { + return err + } + + //TODO(Sida): Make a nicer interface for bulk inserting. + keys := make([]interface{}, len(vulnerabilityIDs)) + for i, id := range vulnerabilityIDs { + keys[i] = id + } + + _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) + if err != nil { + return handleError("removeVulnerabilityAffectedFeature", err) } return nil } + +func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { + var ( + vulnID sql.NullInt64 + vulnIDs []int64 + ) + + // mark vulnerabilities deleted + stmt, err := tx.Prepare(removeVulnerability) + if err != nil { + return nil, handleError("removeVulnerability", err) + } + + defer stmt.Close() + for _, vuln := range vulnerabilities { + err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) + if err != nil { + return nil, handleError("removeVulnerability", err) + } + if !vulnID.Valid { + return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) + } + vulnIDs = append(vulnIDs, vulnID.Int64) + } + return vulnIDs, nil +} + +// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in +// database and the order of output array is not guaranteed. +func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return tx.findVulnerabilityIDs(vulnIDs, true) +} + +func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return tx.findVulnerabilityIDs(vulnIDs, false) +} + +func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { + if len(vulnIDs) == 0 { + return nil, nil + } + + vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{} + keys := make([]interface{}, len(vulnIDs)*2) + for i, vulnID := range vulnIDs { + keys[i*2] = vulnID.Name + keys[i*2+1] = vulnID.Namespace + vulnIDMap[vulnID] = sql.NullInt64{} + } + + query := "" + if withLatestDeleted { + query = querySearchLastDeletedVulnerabilityID(len(vulnIDs)) + } else { + query = querySearchNotDeletedVulnerabilityID(len(vulnIDs)) + } + + rows, err := tx.Query(query, keys...) + if err != nil { + return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) + } + + defer rows.Close() + var ( + id sql.NullInt64 + vulnID database.VulnerabilityID + ) + for rows.Next() { + err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) + if err != nil { + return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) + } + vulnIDMap[vulnID] = id + } + + ids := make([]sql.NullInt64, len(vulnIDs)) + for i, v := range vulnIDs { + ids[i] = vulnIDMap[v] + } + + return ids, nil +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index 61d835bb..9fe2c23b 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -15,282 +15,329 @@ package pgsql import ( - "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) -func TestFindVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("FindVulnerability", true) - if err != nil { - t.Error(err) - return +func TestInsertVulnerabilities(t *testing.T) { + store, tx := openSessionForTest(t, "InsertVulnerabilities", true) + + ns1 := database.Namespace{ + Name: "name", + VersionFormat: "random stuff", } - defer datastore.Close() - // Find a vulnerability that does not exist. - _, err = datastore.FindVulnerability("", "") - assert.Equal(t, commonerr.ErrNotFound, err) + ns2 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } - // Find a normal vulnerability. + // invalid vulnerability v1 := database.Vulnerability{ - Name: "CVE-OPENSSL-1-DEB7", - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: dpkg.ParserName, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{Name: "openssl"}, - Version: "2.0", - }, - { - Feature: database.Feature{Name: "libssl"}, - Version: "1.9-abc", - }, - }, + Name: "invalid", + Namespace: ns1, } - v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) + vwa1 := database.VulnerabilityWithAffected{ + Vulnerability: v1, } - - // Find a vulnerability that has no link, no severity and no FixedIn. + // valid vulnerability v2 := database.Vulnerability{ - Name: "CVE-NOPE", - Description: "A vulnerability affecting nothing", - Namespace: database.Namespace{ - Name: "debian:7", + Name: "valid", + Namespace: ns2, + Severity: database.UnknownSeverity, + } + + vwa2 := database.VulnerabilityWithAffected{ + Vulnerability: v2, + } + + // empty + err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) + assert.Nil(t, err) + + // invalid content: vwa1 is invalid + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) + assert.NotNil(t, err) + + tx = restartSession(t, store, tx, false) + // invalid content: duplicated input + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) + assert.NotNil(t, err) + + tx = restartSession(t, store, tx, false) + // valid content + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) + + tx = restartSession(t, store, tx, true) + // ensure the content is in database + vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) + if assert.Nil(t, err) && assert.Len(t, vulns, 1) { + assert.True(t, vulns[0].Valid) + } + + tx = restartSession(t, store, tx, false) + // valid content: vwa2 removed and inserted + err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) + assert.Nil(t, err) + + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) + + closeTest(t, store, tx) +} + +func TestCachingVulnerable(t *testing.T) { + datastore, tx := openSessionForTest(t, "CachingVulnerable", true) + defer closeTest(t, datastore, tx) + + ns := database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, + } + + f := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", VersionFormat: dpkg.ParserName, }, - Severity: database.UnknownSeverity, + Namespace: ns, } - v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE") - if assert.Nil(t, err) { - equalsVuln(t, &v2, &v2f) - } -} - -func TestDeleteVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Delete non-existing Vulnerability. - err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) - err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1") - assert.Equal(t, commonerr.ErrNotFound, err) - - // Delete Vulnerability. - err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - _, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) - } -} - -func TestInsertVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Create some data. - n1 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace1", - VersionFormat: dpkg.ParserName, - } - n2 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace2", - VersionFormat: dpkg.ParserName, - } - - f1 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n1, + vuln := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: "1.0", - } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n2, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.1", + }, }, - Version: "1.0", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", - }, - Version: versionfmt.MaxVersion, - } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", - }, - Version: "1.4", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion3", - }, - Version: "1.5", - } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion4", - }, - Version: "0.1", - } - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", - }, - Version: versionfmt.MaxVersion, - } - f8 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", - }, - Version: versionfmt.MinVersion, } - // Insert invalid vulnerabilities. - for _, vulnerability := range []database.Vulnerability{ - { - Name: "", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, + vuln2 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - { - Name: "TestInsertVulnerability0", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.1", + FixedInVersion: "2.2", + }, }, - { - Name: "TestInsertVulnerability0-", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, + } + + vulnFixed1 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - { - Name: "TestInsertVulnerability0", - Namespace: n1, - FixedIn: []database.FeatureVersion{f2}, - Severity: database.UnknownSeverity, + FixedInVersion: "2.1", + } + + vulnFixed2 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - } { - err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Error(t, err) + FixedInVersion: "2.2", } - // Insert a simple vulnerability and find it. - v1meta := make(map[string]interface{}) - v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1" - v1meta["TestInsertVulnerabilityMetadata2"] = struct { - Test string - }{ - Test: "TestInsertVulnerabilityMetadataValue1", + if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { + t.FailNow() } - v1 := database.Vulnerability{ - Name: "TestInsertVulnerability1", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1, f3, f6, f7}, - Severity: database.LowSeverity, - Description: "TestInsertVulnerabilityDescription1", - Link: "TestInsertVulnerabilityLink1", - Metadata: v1meta, - } - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) - if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) - } - } - - // Update vulnerability. - v1.Description = "TestInsertVulnerabilityLink2" - v1.Link = "TestInsertVulnerabilityLink2" - v1.Severity = database.HighSeverity - // Update f3 in f4, add fixed in f5, add fixed in f6 which already exists, - // removes fixed in f7 by adding f8 which is f7 but with MinVersion, and - // add fixed by f5 a second time (duplicated). - v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8, f5} - - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) - if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - // Remove f8 from the struct for comparison as it was just here to cancel f7. - // Remove one of the f5 too as it was twice in the struct but the database - // implementation should have dedup'd it. - v1.FixedIn = v1.FixedIn[:len(v1.FixedIn)-2] - - // We already had f1 before the update. - // Add it to the struct for comparison. - v1.FixedIn = append(v1.FixedIn, f1) - - equalsVuln(t, &v1, &v1f) - } - } -} - -func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) { - assert.Equal(t, expected.Name, actual.Name) - assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name) - assert.Equal(t, expected.Description, actual.Description) - assert.Equal(t, expected.Link, actual.Link) - assert.Equal(t, expected.Severity, actual.Severity) - assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata)) - - if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) { - for _, actualFeatureVersion := range actual.FixedIn { - found := false - for _, expectedFeatureVersion := range expected.FixedIn { - if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name { - found = true - - assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name) - assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version) + r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f}) + assert.Nil(t, err) + assert.Len(t, r, 1) + for _, anf := range r { + if assert.True(t, anf.Valid) && assert.Len(t, anf.AffectedBy, 2) { + for _, a := range anf.AffectedBy { + if a.Name == "CVE-YAY" { + assert.Equal(t, vulnFixed1, a) + } else if a.Name == "CVE-YAY2" { + assert.Equal(t, vulnFixed2, a) + } else { + t.FailNow() } } - if !found { - t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name) + } + } +} + +func TestFindVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindVulnerabilities", true) + defer closeTest(t, datastore, tx) + + vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + {Name: "CVE-NOT HERE"}, + }) + + ns := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + expectedExisting := []database.VulnerabilityWithAffected{ + { + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + Affected: []database.AffectedFeature{ + { + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.0", + Namespace: ns, + }, + { + FeatureName: "libssl", + AffectedVersion: "1.9-abc", + FixedInVersion: "1.9-abc", + Namespace: ns, + }, + }, + }, + { + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, + }, + } + + expectedExistingMap := map[database.VulnerabilityID]database.VulnerabilityWithAffected{} + for _, v := range expectedExisting { + expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v + } + + nonexisting := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"}, + } + + if assert.Nil(t, err) { + for _, v := range vuln { + if v.Valid { + key := database.VulnerabilityID{ + Name: v.Name, + Namespace: v.Namespace.Name, + } + + expected, ok := expectedExistingMap[key] + if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { + assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) + } + } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { + t.FailNow() + } + } + } + + // same vulnerability + r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + }) + + if assert.Nil(t, err) { + for _, vuln := range r { + if assert.True(t, vuln.Valid) { + expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] + assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) } } } } -func TestStringComparison(t *testing.T) { - cmp := compareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) - assert.Len(t, cmp, 1) - assert.NotContains(t, cmp, "a") - assert.Contains(t, cmp, "b") +func TestDeleteVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true) + defer closeTest(t, datastore, tx) - cmp = compareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) - assert.Len(t, cmp, 2) - assert.NotContains(t, cmp, "b") - assert.Contains(t, cmp, "a") - assert.Contains(t, cmp, "c") + remove := []database.VulnerabilityID{} + // empty case + assert.Nil(t, tx.DeleteVulnerabilities(remove)) + // invalid case + remove = append(remove, database.VulnerabilityID{}) + assert.NotNil(t, tx.DeleteVulnerabilities(remove)) + + // valid case + validRemove := []database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + } + + assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) + vuln, err := tx.FindVulnerabilities(validRemove) + if assert.Nil(t, err) { + for _, v := range vuln { + assert.False(t, v.Valid) + } + } +} + +func TestFindVulnerabilityIDs(t *testing.T) { + store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true) + defer closeTest(t, store, tx) + + ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) + if assert.Nil(t, err) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) { + assert.Fail(t, "") + } + } + + ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) + if assert.Nil(t, err) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) { + assert.Fail(t, "") + } + } +} + +func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) +} + +func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.AffectedFeature]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + if visited, ok := has[i]; !ok { + return false + } else if visited { + return false + } + has[i] = true + } + return true + } + return false } diff --git a/database/severity.go b/database/severity.go index 58084d64..840f6afb 100644 --- a/database/severity.go +++ b/database/severity.go @@ -36,7 +36,7 @@ const ( // NegligibleSeverity is technically a security problem, but is only // theoretical in nature, requires a very special situation, has almost no // install base, or does no real damage. These tend not to get backport from - // upstreams, and will likely not be included in security updates unless + // upstream, and will likely not be included in security updates unless // there is an easy fix and some other issue causes an update. NegligibleSeverity Severity = "Negligible" @@ -93,7 +93,7 @@ func NewSeverity(s string) (Severity, error) { // Compare determines the equality of two severities. // // If the severities are equal, returns 0. -// If the receiever is less, returns -1. +// If the receiver is less, returns -1. // If the receiver is greater, returns 1. func (s Severity) Compare(s2 Severity) int { var i1, i2 int @@ -132,3 +132,13 @@ func (s *Severity) Scan(value interface{}) error { func (s Severity) Value() (driver.Value, error) { return string(s), nil } + +// Valid checks if the severity is valid or not. +func (s Severity) Valid() bool { + for _, v := range Severities { + if s == v { + return true + } + } + return false +} diff --git a/ext/featurefmt/apk/apk.go b/ext/featurefmt/apk/apk.go index ff63880d..195c8920 100644 --- a/ext/featurefmt/apk/apk.go +++ b/ext/featurefmt/apk/apk.go @@ -34,17 +34,17 @@ func init() { type lister struct{} -func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { file, exists := files["lib/apk/db/installed"] if !exists { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Iterate over each line in the "installed" file attempting to parse each // package into a feature that will be stored in a set to guarantee // uniqueness. - pkgSet := make(map[string]database.FeatureVersion) - ipkg := database.FeatureVersion{} + pkgSet := make(map[string]database.Feature) + ipkg := database.Feature{} scanner := bufio.NewScanner(bytes.NewBuffer(file)) for scanner.Scan() { line := scanner.Text() @@ -55,7 +55,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, // Parse the package name or version. switch { case line[:2] == "P:": - ipkg.Feature.Name = line[2:] + ipkg.Name = line[2:] case line[:2] == "V:": version := string(line[2:]) err := versionfmt.Valid(dpkg.ParserName, version) @@ -67,20 +67,21 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, case line == "": // Restart if the parser reaches another package definition before // creating a valid package. - ipkg = database.FeatureVersion{} + ipkg = database.Feature{} } // If we have a whole feature, store it in the set and try to parse a new // one. - if ipkg.Feature.Name != "" && ipkg.Version != "" { - pkgSet[ipkg.Feature.Name+"#"+ipkg.Version] = ipkg - ipkg = database.FeatureVersion{} + if ipkg.Name != "" && ipkg.Version != "" { + pkgSet[ipkg.Name+"#"+ipkg.Version] = ipkg + ipkg = database.Feature{} } } - // Convert the map into a slice. - pkgs := make([]database.FeatureVersion, 0, len(pkgSet)) + // Convert the map into a slice and attach the version format + pkgs := make([]database.Feature, 0, len(pkgSet)) for _, pkg := range pkgSet { + pkg.VersionFormat = dpkg.ParserName pkgs = append(pkgs, pkg) } diff --git a/ext/featurefmt/apk/apk_test.go b/ext/featurefmt/apk/apk_test.go index 6dbde3e6..d8dc0d88 100644 --- a/ext/featurefmt/apk/apk_test.go +++ b/ext/featurefmt/apk/apk_test.go @@ -19,58 +19,32 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/pkg/tarutil" ) func TestAPKFeatureDetection(t *testing.T) { + testFeatures := []database.Feature{ + {Name: "musl", Version: "1.1.14-r10"}, + {Name: "busybox", Version: "1.24.2-r9"}, + {Name: "alpine-baselayout", Version: "3.0.3-r0"}, + {Name: "alpine-keys", Version: "1.1-r0"}, + {Name: "zlib", Version: "1.2.8-r2"}, + {Name: "libcrypto1.0", Version: "1.0.2h-r1"}, + {Name: "libssl1.0", Version: "1.0.2h-r1"}, + {Name: "apk-tools", Version: "2.6.7-r0"}, + {Name: "scanelf", Version: "1.1.6-r0"}, + {Name: "musl-utils", Version: "1.1.14-r10"}, + {Name: "libc-utils", Version: "0.7-r0"}, + } + + for i := range testFeatures { + testFeatures[i].VersionFormat = dpkg.ParserName + } + testData := []featurefmt.TestData{ { - FeatureVersions: []database.FeatureVersion{ - { - Feature: database.Feature{Name: "musl"}, - Version: "1.1.14-r10", - }, - { - Feature: database.Feature{Name: "busybox"}, - Version: "1.24.2-r9", - }, - { - Feature: database.Feature{Name: "alpine-baselayout"}, - Version: "3.0.3-r0", - }, - { - Feature: database.Feature{Name: "alpine-keys"}, - Version: "1.1-r0", - }, - { - Feature: database.Feature{Name: "zlib"}, - Version: "1.2.8-r2", - }, - { - Feature: database.Feature{Name: "libcrypto1.0"}, - Version: "1.0.2h-r1", - }, - { - Feature: database.Feature{Name: "libssl1.0"}, - Version: "1.0.2h-r1", - }, - { - Feature: database.Feature{Name: "apk-tools"}, - Version: "2.6.7-r0", - }, - { - Feature: database.Feature{Name: "scanelf"}, - Version: "1.1.6-r0", - }, - { - Feature: database.Feature{Name: "musl-utils"}, - Version: "1.1.14-r10", - }, - { - Feature: database.Feature{Name: "libc-utils"}, - Version: "0.7-r0", - }, - }, + Features: testFeatures, Files: tarutil.FilesMap{ "lib/apk/db/installed": featurefmt.LoadFileForTest("apk/testdata/installed"), }, diff --git a/ext/featurefmt/dpkg/dpkg.go b/ext/featurefmt/dpkg/dpkg.go index a0653580..6b987cf3 100644 --- a/ext/featurefmt/dpkg/dpkg.go +++ b/ext/featurefmt/dpkg/dpkg.go @@ -40,16 +40,16 @@ func init() { featurefmt.RegisterLister("dpkg", dpkg.ParserName, &lister{}) } -func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { f, hasFile := files["var/lib/dpkg/status"] if !hasFile { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Create a map to store packages and ensure their uniqueness - packagesMap := make(map[string]database.FeatureVersion) + packagesMap := make(map[string]database.Feature) - var pkg database.FeatureVersion + var pkg database.Feature var err error scanner := bufio.NewScanner(strings.NewReader(string(f))) for scanner.Scan() { @@ -59,7 +59,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, // Package line // Defines the name of the package - pkg.Feature.Name = strings.TrimSpace(strings.TrimPrefix(line, "Package: ")) + pkg.Name = strings.TrimSpace(strings.TrimPrefix(line, "Package: ")) pkg.Version = "" } else if strings.HasPrefix(line, "Source: ") { // Source line (Optionnal) @@ -72,7 +72,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, md[dpkgSrcCaptureRegexpNames[i]] = strings.TrimSpace(n) } - pkg.Feature.Name = md["name"] + pkg.Name = md["name"] if md["version"] != "" { version := md["version"] err = versionfmt.Valid(dpkg.ParserName, version) @@ -96,21 +96,22 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, pkg.Version = version } } else if line == "" { - pkg.Feature.Name = "" + pkg.Name = "" pkg.Version = "" } // Add the package to the result array if we have all the informations - if pkg.Feature.Name != "" && pkg.Version != "" { - packagesMap[pkg.Feature.Name+"#"+pkg.Version] = pkg - pkg.Feature.Name = "" + if pkg.Name != "" && pkg.Version != "" { + packagesMap[pkg.Name+"#"+pkg.Version] = pkg + pkg.Name = "" pkg.Version = "" } } - // Convert the map to a slice - packages := make([]database.FeatureVersion, 0, len(packagesMap)) + // Convert the map to a slice and add version format. + packages := make([]database.Feature, 0, len(packagesMap)) for _, pkg := range packagesMap { + pkg.VersionFormat = dpkg.ParserName packages = append(packages, pkg) } diff --git a/ext/featurefmt/dpkg/dpkg_test.go b/ext/featurefmt/dpkg/dpkg_test.go index a9c3a8cf..1561be4f 100644 --- a/ext/featurefmt/dpkg/dpkg_test.go +++ b/ext/featurefmt/dpkg/dpkg_test.go @@ -19,28 +19,35 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/pkg/tarutil" ) func TestDpkgFeatureDetection(t *testing.T) { + testFeatures := []database.Feature{ + // Two packages from this source are installed, it should only appear one time + { + Name: "pam", + Version: "1.1.8-3.1ubuntu3", + }, + { + Name: "makedev", // The source name and the package name are equals + Version: "2.3.1-93ubuntu1", // The version comes from the "Version:" line + }, + { + Name: "gcc-5", + Version: "5.1.1-12ubuntu1", // The version comes from the "Source:" line + }, + } + + for i := range testFeatures { + testFeatures[i].VersionFormat = dpkg.ParserName + } + testData := []featurefmt.TestData{ // Test an Ubuntu dpkg status file { - FeatureVersions: []database.FeatureVersion{ - // Two packages from this source are installed, it should only appear one time - { - Feature: database.Feature{Name: "pam"}, - Version: "1.1.8-3.1ubuntu3", - }, - { - Feature: database.Feature{Name: "makedev"}, // The source name and the package name are equals - Version: "2.3.1-93ubuntu1", // The version comes from the "Version:" line - }, - { - Feature: database.Feature{Name: "gcc-5"}, - Version: "5.1.1-12ubuntu1", // The version comes from the "Source:" line - }, - }, + Features: testFeatures, Files: tarutil.FilesMap{ "var/lib/dpkg/status": featurefmt.LoadFileForTest("dpkg/testdata/status"), }, diff --git a/ext/featurefmt/driver.go b/ext/featurefmt/driver.go index 8e8d593d..0f48b0e7 100644 --- a/ext/featurefmt/driver.go +++ b/ext/featurefmt/driver.go @@ -38,8 +38,8 @@ var ( // Lister represents an ability to list the features present in an image layer. type Lister interface { - // ListFeatures produces a list of FeatureVersions present in an image layer. - ListFeatures(tarutil.FilesMap) ([]database.FeatureVersion, error) + // ListFeatures produces a list of Features present in an image layer. + ListFeatures(tarutil.FilesMap) ([]database.Feature, error) // RequiredFilenames returns the list of files required to be in the FilesMap // provided to the ListFeatures method. @@ -71,34 +71,24 @@ func RegisterLister(name string, versionfmt string, l Lister) { versionfmtListerName[versionfmt] = append(versionfmtListerName[versionfmt], name) } -// ListFeatures produces the list of FeatureVersions in an image layer using +// ListFeatures produces the list of Features in an image layer using // every registered Lister. -func ListFeatures(files tarutil.FilesMap, namespace *database.Namespace) ([]database.FeatureVersion, error) { +func ListFeatures(files tarutil.FilesMap, listerNames []string) ([]database.Feature, error) { listersM.RLock() defer listersM.RUnlock() - var ( - totalFeatures []database.FeatureVersion - listersName []string - found bool - ) + var totalFeatures []database.Feature - if namespace == nil { - log.Debug("Can't detect features without namespace") - return totalFeatures, nil - } - - if listersName, found = versionfmtListerName[namespace.VersionFormat]; !found { - log.WithFields(log.Fields{"namespace": namespace.Name, "version format": namespace.VersionFormat}).Debug("Unsupported Namespace") - return totalFeatures, nil - } - - for _, listerName := range listersName { - features, err := listers[listerName].ListFeatures(files) - if err != nil { - return totalFeatures, err + for _, name := range listerNames { + if lister, ok := listers[name]; ok { + features, err := lister.ListFeatures(files) + if err != nil { + return []database.Feature{}, err + } + totalFeatures = append(totalFeatures, features...) + } else { + log.WithField("Name", name).Warn("Unknown Lister") } - totalFeatures = append(totalFeatures, features...) } return totalFeatures, nil @@ -106,7 +96,7 @@ func ListFeatures(files tarutil.FilesMap, namespace *database.Namespace) ([]data // RequiredFilenames returns the total list of files required for all // registered Listers. -func RequiredFilenames() (files []string) { +func RequiredFilenames(listerNames []string) (files []string) { listersM.RLock() defer listersM.RUnlock() @@ -117,10 +107,19 @@ func RequiredFilenames() (files []string) { return } +// ListListers returns the names of all the registered feature listers. +func ListListers() []string { + r := []string{} + for name := range listers { + r = append(r, name) + } + return r +} + // TestData represents the data used to test an implementation of Lister. type TestData struct { - Files tarutil.FilesMap - FeatureVersions []database.FeatureVersion + Files tarutil.FilesMap + Features []database.Feature } // LoadFileForTest can be used in order to obtain the []byte contents of a file @@ -136,9 +135,9 @@ func LoadFileForTest(name string) []byte { func TestLister(t *testing.T, l Lister, testData []TestData) { for _, td := range testData { featureVersions, err := l.ListFeatures(td.Files) - if assert.Nil(t, err) && assert.Len(t, featureVersions, len(td.FeatureVersions)) { - for _, expectedFeatureVersion := range td.FeatureVersions { - assert.Contains(t, featureVersions, expectedFeatureVersion) + if assert.Nil(t, err) && assert.Len(t, featureVersions, len(td.Features)) { + for _, expectedFeature := range td.Features { + assert.Contains(t, featureVersions, expectedFeature) } } } diff --git a/ext/featurefmt/rpm/rpm.go b/ext/featurefmt/rpm/rpm.go index 9e62f0fc..5a0e1fa1 100644 --- a/ext/featurefmt/rpm/rpm.go +++ b/ext/featurefmt/rpm/rpm.go @@ -38,27 +38,27 @@ func init() { featurefmt.RegisterLister("rpm", rpm.ParserName, &lister{}) } -func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { f, hasFile := files["var/lib/rpm/Packages"] if !hasFile { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Create a map to store packages and ensure their uniqueness - packagesMap := make(map[string]database.FeatureVersion) + packagesMap := make(map[string]database.Feature) // Write the required "Packages" file to disk tmpDir, err := ioutil.TempDir(os.TempDir(), "rpm") defer os.RemoveAll(tmpDir) if err != nil { log.WithError(err).Error("could not create temporary folder for RPM detection") - return []database.FeatureVersion{}, commonerr.ErrFilesystem + return []database.Feature{}, commonerr.ErrFilesystem } err = ioutil.WriteFile(tmpDir+"/Packages", f, 0700) if err != nil { log.WithError(err).Error("could not create temporary file for RPM detection") - return []database.FeatureVersion{}, commonerr.ErrFilesystem + return []database.Feature{}, commonerr.ErrFilesystem } // Extract binary package names because RHSA refers to binary package names. @@ -67,7 +67,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, log.WithError(err).WithField("output", string(out)).Error("could not query RPM") // Do not bubble up because we probably won't be able to fix it, // the database must be corrupted - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } scanner := bufio.NewScanner(strings.NewReader(string(out))) @@ -93,18 +93,17 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, } // Add package - pkg := database.FeatureVersion{ - Feature: database.Feature{ - Name: line[0], - }, + pkg := database.Feature{ + Name: line[0], Version: version, } - packagesMap[pkg.Feature.Name+"#"+pkg.Version] = pkg + packagesMap[pkg.Name+"#"+pkg.Version] = pkg } // Convert the map to a slice - packages := make([]database.FeatureVersion, 0, len(packagesMap)) + packages := make([]database.Feature, 0, len(packagesMap)) for _, pkg := range packagesMap { + pkg.VersionFormat = rpm.ParserName packages = append(packages, pkg) } diff --git a/ext/featurefmt/rpm/rpm_test.go b/ext/featurefmt/rpm/rpm_test.go index 1b6f531c..0b674523 100644 --- a/ext/featurefmt/rpm/rpm_test.go +++ b/ext/featurefmt/rpm/rpm_test.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/rpm" "github.com/coreos/clair/pkg/tarutil" ) @@ -27,16 +28,18 @@ func TestRpmFeatureDetection(t *testing.T) { // Test a CentOS 7 RPM database // Memo: Use the following command on a RPM-based system to shrink a database: rpm -qa --qf "%{NAME}\n" |tail -n +3| xargs rpm -e --justdb { - FeatureVersions: []database.FeatureVersion{ + Features: []database.Feature{ // Two packages from this source are installed, it should only appear once { - Feature: database.Feature{Name: "centos-release"}, - Version: "7-1.1503.el7.centos.2.8", + Name: "centos-release", + Version: "7-1.1503.el7.centos.2.8", + VersionFormat: rpm.ParserName, }, // Two packages from this source are installed, it should only appear once { - Feature: database.Feature{Name: "filesystem"}, - Version: "3.2-18.el7", + Name: "filesystem", + Version: "3.2-18.el7", + VersionFormat: rpm.ParserName, }, }, Files: tarutil.FilesMap{ diff --git a/ext/featurens/driver.go b/ext/featurens/driver.go index 754ed8c5..b7e0ad37 100644 --- a/ext/featurens/driver.go +++ b/ext/featurens/driver.go @@ -69,20 +69,24 @@ func RegisterDetector(name string, d Detector) { } // Detect iterators through all registered Detectors and returns all non-nil detected namespaces -func Detect(files tarutil.FilesMap) ([]database.Namespace, error) { +func Detect(files tarutil.FilesMap, detectorNames []string) ([]database.Namespace, error) { detectorsM.RLock() defer detectorsM.RUnlock() namespaces := map[string]*database.Namespace{} - for name, detector := range detectors { - namespace, err := detector.Detect(files) - if err != nil { - log.WithError(err).WithField("name", name).Warning("failed while attempting to detect namespace") - return []database.Namespace{}, err - } + for _, name := range detectorNames { + if detector, ok := detectors[name]; ok { + namespace, err := detector.Detect(files) + if err != nil { + log.WithError(err).WithField("name", name).Warning("failed while attempting to detect namespace") + return nil, err + } - if namespace != nil { - log.WithFields(log.Fields{"name": name, "namespace": namespace.Name}).Debug("detected namespace") - namespaces[namespace.Name] = namespace + if namespace != nil { + log.WithFields(log.Fields{"name": name, "namespace": namespace.Name}).Debug("detected namespace") + namespaces[namespace.Name] = namespace + } + } else { + log.WithField("Name", name).Warn("Unknown namespace detector") } } @@ -95,7 +99,7 @@ func Detect(files tarutil.FilesMap) ([]database.Namespace, error) { // RequiredFilenames returns the total list of files required for all // registered Detectors. -func RequiredFilenames() (files []string) { +func RequiredFilenames(detectorNames []string) (files []string) { detectorsM.RLock() defer detectorsM.RUnlock() @@ -106,6 +110,15 @@ func RequiredFilenames() (files []string) { return } +// ListDetectors returns the names of all registered namespace detectors. +func ListDetectors() []string { + r := []string{} + for name := range detectors { + r = append(r, name) + } + return r +} + // TestData represents the data used to test an implementation of Detector. type TestData struct { Files tarutil.FilesMap diff --git a/ext/featurens/driver_test.go b/ext/featurens/driver_test.go index e1a47ef6..8493c0cc 100644 --- a/ext/featurens/driver_test.go +++ b/ext/featurens/driver_test.go @@ -8,7 +8,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/pkg/tarutil" - + _ "github.com/coreos/clair/ext/featurens/alpinerelease" _ "github.com/coreos/clair/ext/featurens/aptsources" _ "github.com/coreos/clair/ext/featurens/lsbrelease" @@ -35,7 +35,7 @@ func assertnsNameEqual(t *testing.T, nslist_expected, nslist []database.Namespac func testMultipleNamespace(t *testing.T, testData []MultipleNamespaceTestData) { for _, td := range testData { - nslist, err := featurens.Detect(td.Files) + nslist, err := featurens.Detect(td.Files, featurens.ListDetectors()) assert.Nil(t, err) assertnsNameEqual(t, td.ExpectedNamespaces, nslist) } diff --git a/ext/imagefmt/driver.go b/ext/imagefmt/driver.go index 6997e93b..178de53b 100644 --- a/ext/imagefmt/driver.go +++ b/ext/imagefmt/driver.go @@ -38,7 +38,7 @@ import ( var ( // ErrCouldNotFindLayer is returned when we could not download or open the layer file. - ErrCouldNotFindLayer = commonerr.NewBadRequestError("could not find layer") + ErrCouldNotFindLayer = commonerr.NewBadRequestError("could not find layer from given path") // insecureTLS controls whether TLS server's certificate chain and hostname are verified // when pulling layers, verified in default. diff --git a/ext/notification/driver.go b/ext/notification/driver.go index 8b961ae8..2768b7e3 100644 --- a/ext/notification/driver.go +++ b/ext/notification/driver.go @@ -23,8 +23,6 @@ package notification import ( "sync" "time" - - "github.com/coreos/clair/database" ) var ( @@ -47,7 +45,7 @@ type Sender interface { Configure(*Config) (bool, error) // Send informs the existence of the specified notification. - Send(notification database.VulnerabilityNotification) error + Send(notificationName string) error } // RegisterSender makes a Sender available by the provided name. diff --git a/ext/notification/webhook/webhook.go b/ext/notification/webhook/webhook.go index d54b588b..14ef48b2 100644 --- a/ext/notification/webhook/webhook.go +++ b/ext/notification/webhook/webhook.go @@ -29,7 +29,6 @@ import ( "gopkg.in/yaml.v2" - "github.com/coreos/clair/database" "github.com/coreos/clair/ext/notification" ) @@ -112,9 +111,9 @@ type notificationEnvelope struct { } } -func (s *sender) Send(notification database.VulnerabilityNotification) error { +func (s *sender) Send(notificationName string) error { // Marshal notification. - jsonNotification, err := json.Marshal(notificationEnvelope{struct{ Name string }{notification.Name}}) + jsonNotification, err := json.Marshal(notificationEnvelope{struct{ Name string }{notificationName}}) if err != nil { return fmt.Errorf("could not marshal: %s", err) } diff --git a/ext/versionfmt/dpkg/parser.go b/ext/versionfmt/dpkg/parser.go index 2d6eefbc..a2c82ec6 100644 --- a/ext/versionfmt/dpkg/parser.go +++ b/ext/versionfmt/dpkg/parser.go @@ -120,6 +120,18 @@ func (p parser) Valid(str string) bool { return err == nil } +func (p parser) InRange(versionA, rangeB string) (bool, error) { + cmp, err := p.Compare(versionA, rangeB) + if err != nil { + return false, err + } + return cmp < 0, nil +} + +func (p parser) GetFixedIn(fixedIn string) (string, error) { + return fixedIn, nil +} + // Compare function compares two Debian-like package version // // The implementation is based on http://man.he.net/man5/deb-version diff --git a/ext/versionfmt/driver.go b/ext/versionfmt/driver.go index 42f6c5b8..03179cd1 100644 --- a/ext/versionfmt/driver.go +++ b/ext/versionfmt/driver.go @@ -19,6 +19,8 @@ package versionfmt import ( "errors" "sync" + + log "github.com/sirupsen/logrus" ) const ( @@ -50,6 +52,18 @@ type Parser interface { // Compare parses two different version strings. // Returns 0 when equal, -1 when a < b, 1 when b < a. Compare(a, b string) (int, error) + + // InRange computes if a is in range of b + // + // NOTE(Sida): For legacy version formats, rangeB is a version and + // always use if versionA < rangeB as threshold. + InRange(versionA, rangeB string) (bool, error) + + // GetFixedIn computes a fixed in version for a certain version range. + // + // NOTE(Sida): For legacy version formats, rangeA is a version and + // be returned directly becuase it was considered fixed in version. + GetFixedIn(rangeA string) (string, error) } // RegisterParser provides a way to dynamically register an implementation of a @@ -110,3 +124,28 @@ func Compare(format, versionA, versionB string) (int, error) { return versionParser.Compare(versionA, versionB) } + +// InRange is a helper function that checks if `versionA` is in `rangeB` +func InRange(format, version, versionRange string) (bool, error) { + versionParser, exists := GetParser(format) + if !exists { + return false, ErrUnknownVersionFormat + } + + in, err := versionParser.InRange(version, versionRange) + if err != nil { + log.WithFields(log.Fields{"Format": format, "Version": version, "Range": versionRange}).Error(err) + } + return in, err +} + +// GetFixedIn is a helper function that computes the next fixed in version given +// a affected version range `rangeA`. +func GetFixedIn(format, rangeA string) (string, error) { + versionParser, exists := GetParser(format) + if !exists { + return "", ErrUnknownVersionFormat + } + + return versionParser.GetFixedIn(rangeA) +} diff --git a/ext/versionfmt/rpm/parser.go b/ext/versionfmt/rpm/parser.go index 34fbb9b9..55266ca5 100644 --- a/ext/versionfmt/rpm/parser.go +++ b/ext/versionfmt/rpm/parser.go @@ -121,6 +121,20 @@ func (p parser) Valid(str string) bool { return err == nil } +func (p parser) InRange(versionA, rangeB string) (bool, error) { + cmp, err := p.Compare(versionA, rangeB) + if err != nil { + return false, err + } + return cmp < 0, nil +} + +func (p parser) GetFixedIn(fixedIn string) (string, error) { + // In the old version format parser design, the string to determine fixed in + // version is the fixed in version. + return fixedIn, nil +} + func (p parser) Compare(a, b string) (int, error) { v1, err := newVersion(a) if err != nil { diff --git a/ext/vulnsrc/alpine/alpine.go b/ext/vulnsrc/alpine/alpine.go index 271b6553..5b6f46e1 100644 --- a/ext/vulnsrc/alpine/alpine.go +++ b/ext/vulnsrc/alpine/alpine.go @@ -60,10 +60,20 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er // Ask the database for the latest commit we successfully applied. var dbCommit string - dbCommit, err = db.GetKeyValue(updaterFlag) + tx, err := db.Begin() if err != nil { return } + defer tx.Rollback() + + dbCommit, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return + } + + if !ok { + dbCommit = "" + } // Set the updaterFlag to equal the commit processed. resp.FlagName = updaterFlag @@ -84,7 +94,7 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er // Append any changed vulnerabilities to the response. for _, namespace := range namespaces { - var vulns []database.Vulnerability + var vulns []database.VulnerabilityWithAffected var note string vulns, note, err = parseVulnsFromNamespace(u.repositoryLocalPath, namespace) if err != nil { @@ -144,7 +154,7 @@ func ls(path string, filter lsFilter) ([]string, error) { return files, nil } -func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database.Vulnerability, note string, err error) { +func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database.VulnerabilityWithAffected, note string, err error) { nsDir := filepath.Join(repositoryPath, namespace) var dbFilenames []string dbFilenames, err = ls(nsDir, filesOnly) @@ -159,7 +169,7 @@ func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database return } - var fileVulns []database.Vulnerability + var fileVulns []database.VulnerabilityWithAffected fileVulns, err = parseYAML(file) if err != nil { return @@ -216,7 +226,7 @@ type secDBFile struct { } `yaml:"packages"` } -func parseYAML(r io.Reader) (vulns []database.Vulnerability, err error) { +func parseYAML(r io.Reader) (vulns []database.VulnerabilityWithAffected, err error) { var rBytes []byte rBytes, err = ioutil.ReadAll(r) if err != nil { @@ -239,20 +249,24 @@ func parseYAML(r io.Reader) (vulns []database.Vulnerability, err error) { } for _, vulnStr := range vulnStrs { - var vuln database.Vulnerability + var vuln database.VulnerabilityWithAffected vuln.Severity = database.UnknownSeverity vuln.Name = vulnStr vuln.Link = nvdURLPrefix + vulnStr - vuln.FixedIn = []database.FeatureVersion{ + + var fixedInVersion string + if version != versionfmt.MaxVersion { + fixedInVersion = version + } + vuln.Affected = []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "alpine:" + file.Distro, - VersionFormat: dpkg.ParserName, - }, - Name: pkg.Name, + FeatureName: pkg.Name, + AffectedVersion: version, + FixedInVersion: fixedInVersion, + Namespace: database.Namespace{ + Name: "alpine:" + file.Distro, + VersionFormat: dpkg.ParserName, }, - Version: version, }, } vulns = append(vulns, vuln) diff --git a/ext/vulnsrc/alpine/alpine_test.go b/ext/vulnsrc/alpine/alpine_test.go index ac95f5c5..eddcc759 100644 --- a/ext/vulnsrc/alpine/alpine_test.go +++ b/ext/vulnsrc/alpine/alpine_test.go @@ -36,7 +36,7 @@ func TestYAMLParsing(t *testing.T) { } assert.Equal(t, 105, len(vulns)) assert.Equal(t, "CVE-2016-5387", vulns[0].Name) - assert.Equal(t, "alpine:v3.4", vulns[0].FixedIn[0].Feature.Namespace.Name) - assert.Equal(t, "apache2", vulns[0].FixedIn[0].Feature.Name) + assert.Equal(t, "alpine:v3.4", vulns[0].Affected[0].Namespace.Name) + assert.Equal(t, "apache2", vulns[0].Affected[0].FeatureName) assert.Equal(t, "https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-5387", vulns[0].Link) } diff --git a/ext/vulnsrc/debian/debian.go b/ext/vulnsrc/debian/debian.go index 3288e46b..c0efc37e 100644 --- a/ext/vulnsrc/debian/debian.go +++ b/ext/vulnsrc/debian/debian.go @@ -62,6 +62,27 @@ func init() { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "Debian").Info("Start fetching vulnerabilities") + tx, err := datastore.Begin() + if err != nil { + return resp, err + } + + // Get the SHA-1 of the latest update's JSON data + latestHash, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return resp, err + } + + // NOTE(sida): The transaction won't mutate the database and I want the + // transaction to be short. + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + latestHash = "" + } + // Download JSON. r, err := http.Get(url) if err != nil { @@ -69,12 +90,6 @@ func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateRespo return resp, commonerr.ErrCouldNotDownload } - // Get the SHA-1 of the latest update's JSON data - latestHash, err := datastore.GetKeyValue(updaterFlag) - if err != nil { - return resp, err - } - // Parse the JSON. resp, err = buildResponse(r.Body, latestHash) if err != nil { @@ -131,8 +146,8 @@ func buildResponse(jsonReader io.Reader, latestKnownHash string) (resp vulnsrc.U return resp, nil } -func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, unknownReleases map[string]struct{}) { - mvulnerabilities := make(map[string]*database.Vulnerability) +func parseDebianJSON(data *jsonData) (vulnerabilities []database.VulnerabilityWithAffected, unknownReleases map[string]struct{}) { + mvulnerabilities := make(map[string]*database.VulnerabilityWithAffected) unknownReleases = make(map[string]struct{}) for pkgName, pkgNode := range *data { @@ -145,6 +160,7 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, } // Skip if the status is not determined or the vulnerability is a temporary one. + // TODO: maybe add "undetermined" as Unknown severity. if !strings.HasPrefix(vulnName, "CVE-") || releaseNode.Status == "undetermined" { continue } @@ -152,11 +168,13 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, // Get or create the vulnerability. vulnerability, vulnerabilityAlreadyExists := mvulnerabilities[vulnName] if !vulnerabilityAlreadyExists { - vulnerability = &database.Vulnerability{ - Name: vulnName, - Link: strings.Join([]string{cveURLPrefix, "/", vulnName}, ""), - Severity: database.UnknownSeverity, - Description: vulnNode.Description, + vulnerability = &database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: vulnName, + Link: strings.Join([]string{cveURLPrefix, "/", vulnName}, ""), + Severity: database.UnknownSeverity, + Description: vulnNode.Description, + }, } } @@ -171,10 +189,7 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, // Determine the version of the package the vulnerability affects. var version string var err error - if releaseNode.FixedVersion == "0" { - // This means that the package is not affected by this vulnerability. - version = versionfmt.MinVersion - } else if releaseNode.Status == "open" { + if releaseNode.Status == "open" { // Open means that the package is currently vulnerable in the latest // version of this Debian release. version = versionfmt.MaxVersion @@ -186,21 +201,34 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") continue } - version = releaseNode.FixedVersion + + // FixedVersion = "0" means that the vulnerability affecting + // current feature is not that important + if releaseNode.FixedVersion != "0" { + version = releaseNode.FixedVersion + } + } + + if version == "" { + continue + } + + var fixedInVersion string + if version != versionfmt.MaxVersion { + fixedInVersion = version } // Create and add the feature version. - pkg := database.FeatureVersion{ - Feature: database.Feature{ - Name: pkgName, - Namespace: database.Namespace{ - Name: "debian:" + database.DebianReleasesMapping[releaseName], - VersionFormat: dpkg.ParserName, - }, + pkg := database.AffectedFeature{ + FeatureName: pkgName, + AffectedVersion: version, + FixedInVersion: fixedInVersion, + Namespace: database.Namespace{ + Name: "debian:" + database.DebianReleasesMapping[releaseName], + VersionFormat: dpkg.ParserName, }, - Version: version, } - vulnerability.FixedIn = append(vulnerability.FixedIn, pkg) + vulnerability.Affected = append(vulnerability.Affected, pkg) // Store the vulnerability. mvulnerabilities[vulnName] = vulnerability @@ -223,30 +251,16 @@ func SeverityFromUrgency(urgency string) database.Severity { case "not yet assigned": return database.UnknownSeverity - case "end-of-life": - fallthrough - case "unimportant": + case "end-of-life", "unimportant": return database.NegligibleSeverity - case "low": - fallthrough - case "low*": - fallthrough - case "low**": + case "low", "low*", "low**": return database.LowSeverity - case "medium": - fallthrough - case "medium*": - fallthrough - case "medium**": + case "medium", "medium*", "medium**": return database.MediumSeverity - case "high": - fallthrough - case "high*": - fallthrough - case "high**": + case "high", "high*", "high**": return database.HighSeverity default: diff --git a/ext/vulnsrc/debian/debian_test.go b/ext/vulnsrc/debian/debian_test.go index 1c62500c..3a6f9ace 100644 --- a/ext/vulnsrc/debian/debian_test.go +++ b/ext/vulnsrc/debian/debian_test.go @@ -32,103 +32,76 @@ func TestDebianParser(t *testing.T) { // Test parsing testdata/fetcher_debian_test.json testFile, _ := os.Open(filepath.Join(filepath.Dir(filename)) + "/testdata/fetcher_debian_test.json") response, err := buildResponse(testFile, "") - if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 3) { + if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 2) { for _, vulnerability := range response.Vulnerabilities { if vulnerability.Name == "CVE-2015-1323" { assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2015-1323", vulnerability.Link) assert.Equal(t, database.LowSeverity, vulnerability.Severity) assert.Equal(t, "This vulnerability is not very dangerous.", vulnerability.Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: versionfmt.MaxVersion, + FeatureName: "aptdaemon", + AffectedVersion: versionfmt.MaxVersion, }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:unstable", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:unstable", + VersionFormat: dpkg.ParserName, }, - Version: "1.1.1+bzr982-1", + FeatureName: "aptdaemon", + AffectedVersion: "1.1.1+bzr982-1", + FixedInVersion: "1.1.1+bzr982-1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } else if vulnerability.Name == "CVE-2003-0779" { assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2003-0779", vulnerability.Link) assert.Equal(t, database.HighSeverity, vulnerability.Severity) assert.Equal(t, "But this one is very dangerous.", vulnerability.Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: "0.7.0", + FeatureName: "aptdaemon", + FixedInVersion: "0.7.0", + AffectedVersion: "0.7.0", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:unstable", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:unstable", + VersionFormat: dpkg.ParserName, }, - Version: "0.7.0", + FeatureName: "aptdaemon", + FixedInVersion: "0.7.0", + AffectedVersion: "0.7.0", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "asterisk", + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: "0.5.56", + FeatureName: "asterisk", + FixedInVersion: "0.5.56", + AffectedVersion: "0.5.56", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) - } - } else if vulnerability.Name == "CVE-2013-2685" { - assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2013-2685", vulnerability.Link) - assert.Equal(t, database.NegligibleSeverity, vulnerability.Severity) - assert.Equal(t, "Un-affected packages.", vulnerability.Description) - - expectedFeatureVersions := []database.FeatureVersion{ - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "asterisk", - }, - Version: versionfmt.MinVersion, - }, - } - - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } else { - assert.Fail(t, "Wrong vulnerability name: ", vulnerability.ID) + assert.Fail(t, "Wrong vulnerability name: ", vulnerability.Namespace.Name+":"+vulnerability.Name) } } } diff --git a/ext/vulnsrc/driver.go b/ext/vulnsrc/driver.go index fd442416..91b28831 100644 --- a/ext/vulnsrc/driver.go +++ b/ext/vulnsrc/driver.go @@ -39,11 +39,10 @@ type UpdateResponse struct { FlagName string FlagValue string Notes []string - Vulnerabilities []database.Vulnerability + Vulnerabilities []database.VulnerabilityWithAffected } -// Updater represents anything that can fetch vulnerabilities and insert them -// into a Clair datastore. +// Updater represents anything that can fetch vulnerabilities. type Updater interface { // Update gets vulnerability updates. Update(database.Datastore) (UpdateResponse, error) @@ -88,3 +87,12 @@ func Updaters() map[string]Updater { return ret } + +// ListUpdaters returns the names of registered vulnerability updaters. +func ListUpdaters() []string { + r := []string{} + for u := range updaters { + r = append(r, u) + } + return r +} diff --git a/ext/vulnsrc/oracle/oracle.go b/ext/vulnsrc/oracle/oracle.go index ee6a8343..40dcd669 100644 --- a/ext/vulnsrc/oracle/oracle.go +++ b/ext/vulnsrc/oracle/oracle.go @@ -118,10 +118,20 @@ func compareELSA(left, right int) int { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "Oracle Linux").Info("Start fetching vulnerabilities") // Get the first ELSA we have to manage. - flagValue, err := datastore.GetKeyValue(updaterFlag) + tx, err := datastore.Begin() if err != nil { return resp, err } + defer tx.Rollback() + + flagValue, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return resp, err + } + + if !ok { + flagValue = "" + } firstELSA, err := strconv.Atoi(flagValue) if firstELSA == 0 || err != nil { @@ -192,7 +202,7 @@ func largest(list []int) (largest int) { func (u *updater) Clean() {} -func parseELSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, err error) { +func parseELSA(ovalReader io.Reader) (vulnerabilities []database.VulnerabilityWithAffected, err error) { // Decode the XML. var ov oval err = xml.NewDecoder(ovalReader).Decode(&ov) @@ -205,16 +215,18 @@ func parseELSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, // Iterate over the definitions and collect any vulnerabilities that affect // at least one package. for _, definition := range ov.Definitions { - pkgs := toFeatureVersions(definition.Criteria) + pkgs := toFeatures(definition.Criteria) if len(pkgs) > 0 { - vulnerability := database.Vulnerability{ - Name: name(definition), - Link: link(definition), - Severity: severity(definition), - Description: description(definition), + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: name(definition), + Link: link(definition), + Severity: severity(definition), + Description: description(definition), + }, } for _, p := range pkgs { - vulnerability.FixedIn = append(vulnerability.FixedIn, p) + vulnerability.Affected = append(vulnerability.Affected, p) } vulnerabilities = append(vulnerabilities, vulnerability) } @@ -298,15 +310,15 @@ func getPossibilities(node criteria) [][]criterion { return possibilities } -func toFeatureVersions(criteria criteria) []database.FeatureVersion { +func toFeatures(criteria criteria) []database.AffectedFeature { // There are duplicates in Oracle .xml files. // This map is for deduplication. - featureVersionParameters := make(map[string]database.FeatureVersion) + featureVersionParameters := make(map[string]database.AffectedFeature) possibilities := getPossibilities(criteria) for _, criterions := range possibilities { var ( - featureVersion database.FeatureVersion + featureVersion database.AffectedFeature osVersion int err error ) @@ -321,29 +333,32 @@ func toFeatureVersions(criteria criteria) []database.FeatureVersion { } } else if strings.Contains(c.Comment, " is earlier than ") { const prefixLen = len(" is earlier than ") - featureVersion.Feature.Name = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) + featureVersion.FeatureName = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) version := c.Comment[strings.Index(c.Comment, " is earlier than ")+prefixLen:] err := versionfmt.Valid(rpm.ParserName, version) if err != nil { log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") } else { - featureVersion.Version = version + featureVersion.AffectedVersion = version + if version != versionfmt.MaxVersion { + featureVersion.FixedInVersion = version + } } } } - featureVersion.Feature.Namespace.Name = "oracle" + ":" + strconv.Itoa(osVersion) - featureVersion.Feature.Namespace.VersionFormat = rpm.ParserName + featureVersion.Namespace.Name = "oracle" + ":" + strconv.Itoa(osVersion) + featureVersion.Namespace.VersionFormat = rpm.ParserName - if featureVersion.Feature.Namespace.Name != "" && featureVersion.Feature.Name != "" && featureVersion.Version != "" { - featureVersionParameters[featureVersion.Feature.Namespace.Name+":"+featureVersion.Feature.Name] = featureVersion + if featureVersion.Namespace.Name != "" && featureVersion.FeatureName != "" && featureVersion.AffectedVersion != "" && featureVersion.FixedInVersion != "" { + featureVersionParameters[featureVersion.Namespace.Name+":"+featureVersion.FeatureName] = featureVersion } else { log.WithField("criterions", fmt.Sprintf("%v", criterions)).Warning("could not determine a valid package from criterions") } } // Convert the map to slice. - var featureVersionParametersArray []database.FeatureVersion + var featureVersionParametersArray []database.AffectedFeature for _, fv := range featureVersionParameters { featureVersionParametersArray = append(featureVersionParametersArray, fv) } diff --git a/ext/vulnsrc/oracle/oracle_test.go b/ext/vulnsrc/oracle/oracle_test.go index bab98bcc..a9348d48 100644 --- a/ext/vulnsrc/oracle/oracle_test.go +++ b/ext/vulnsrc/oracle/oracle_test.go @@ -40,41 +40,38 @@ func TestOracleParser(t *testing.T) { assert.Equal(t, database.MediumSeverity, vulnerabilities[0].Severity) assert.Equal(t, ` [3.1.1-7] Resolves: rhbz#1217104 CVE-2015-0252 `, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c", + FixedInVersion: "0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-devel", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-devel", + FixedInVersion: "0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-doc", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-doc", + FixedInVersion: "0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } @@ -87,31 +84,29 @@ func TestOracleParser(t *testing.T) { assert.Equal(t, "http://linux.oracle.com/errata/ELSA-2015-1207.html", vulnerabilities[0].Link) assert.Equal(t, database.CriticalSeverity, vulnerabilities[0].Severity) assert.Equal(t, ` [38.1.0-1.0.1.el7_1] - Add firefox-oracle-default-prefs.js and remove the corresponding Red Hat file [38.1.0-1] - Update to 38.1.0 ESR [38.0.1-2] - Fixed rhbz#1222807 by removing preun section `, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:6", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "oracle:6", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.0.1.el6_6", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.0.1.el6_6", + AffectedVersion: "0:38.1.0-1.0.1.el6_6", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.0.1.el7_1", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.0.1.el7_1", + AffectedVersion: "0:38.1.0-1.0.1.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } } diff --git a/ext/vulnsrc/rhel/rhel.go b/ext/vulnsrc/rhel/rhel.go index bbd48c15..f4cbce8f 100644 --- a/ext/vulnsrc/rhel/rhel.go +++ b/ext/vulnsrc/rhel/rhel.go @@ -90,11 +90,26 @@ func init() { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "RHEL").Info("Start fetching vulnerabilities") - // Get the first RHSA we have to manage. - flagValue, err := datastore.GetKeyValue(updaterFlag) + + tx, err := datastore.Begin() if err != nil { return resp, err } + + // Get the first RHSA we have to manage. + flagValue, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return resp, err + } + + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + flagValue = "" + } + firstRHSA, err := strconv.Atoi(flagValue) if firstRHSA == 0 || err != nil { firstRHSA = firstRHEL5RHSA @@ -154,7 +169,7 @@ func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateRespo func (u *updater) Clean() {} -func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, err error) { +func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.VulnerabilityWithAffected, err error) { // Decode the XML. var ov oval err = xml.NewDecoder(ovalReader).Decode(&ov) @@ -167,16 +182,18 @@ func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, // Iterate over the definitions and collect any vulnerabilities that affect // at least one package. for _, definition := range ov.Definitions { - pkgs := toFeatureVersions(definition.Criteria) + pkgs := toFeatures(definition.Criteria) if len(pkgs) > 0 { - vulnerability := database.Vulnerability{ - Name: name(definition), - Link: link(definition), - Severity: severity(definition), - Description: description(definition), + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: name(definition), + Link: link(definition), + Severity: severity(definition), + Description: description(definition), + }, } for _, p := range pkgs { - vulnerability.FixedIn = append(vulnerability.FixedIn, p) + vulnerability.Affected = append(vulnerability.Affected, p) } vulnerabilities = append(vulnerabilities, vulnerability) } @@ -260,15 +277,15 @@ func getPossibilities(node criteria) [][]criterion { return possibilities } -func toFeatureVersions(criteria criteria) []database.FeatureVersion { +func toFeatures(criteria criteria) []database.AffectedFeature { // There are duplicates in Red Hat .xml files. // This map is for deduplication. - featureVersionParameters := make(map[string]database.FeatureVersion) + featureVersionParameters := make(map[string]database.AffectedFeature) possibilities := getPossibilities(criteria) for _, criterions := range possibilities { var ( - featureVersion database.FeatureVersion + featureVersion database.AffectedFeature osVersion int err error ) @@ -283,34 +300,37 @@ func toFeatureVersions(criteria criteria) []database.FeatureVersion { } } else if strings.Contains(c.Comment, " is earlier than ") { const prefixLen = len(" is earlier than ") - featureVersion.Feature.Name = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) + featureVersion.FeatureName = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) version := c.Comment[strings.Index(c.Comment, " is earlier than ")+prefixLen:] err := versionfmt.Valid(rpm.ParserName, version) if err != nil { log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") } else { - featureVersion.Version = version - featureVersion.Feature.Namespace.VersionFormat = rpm.ParserName + featureVersion.AffectedVersion = version + if version != versionfmt.MaxVersion { + featureVersion.FixedInVersion = version + } + featureVersion.Namespace.VersionFormat = rpm.ParserName } } } if osVersion >= firstConsideredRHEL { // TODO(vbatts) this is where features need multiple labels ('centos' and 'rhel') - featureVersion.Feature.Namespace.Name = "centos" + ":" + strconv.Itoa(osVersion) + featureVersion.Namespace.Name = "centos" + ":" + strconv.Itoa(osVersion) } else { continue } - if featureVersion.Feature.Namespace.Name != "" && featureVersion.Feature.Name != "" && featureVersion.Version != "" { - featureVersionParameters[featureVersion.Feature.Namespace.Name+":"+featureVersion.Feature.Name] = featureVersion + if featureVersion.Namespace.Name != "" && featureVersion.FeatureName != "" && featureVersion.AffectedVersion != "" && featureVersion.FixedInVersion != "" { + featureVersionParameters[featureVersion.Namespace.Name+":"+featureVersion.FeatureName] = featureVersion } else { log.WithField("criterions", fmt.Sprintf("%v", criterions)).Warning("could not determine a valid package from criterions") } } // Convert the map to slice. - var featureVersionParametersArray []database.FeatureVersion + var featureVersionParametersArray []database.AffectedFeature for _, fv := range featureVersionParameters { featureVersionParametersArray = append(featureVersionParametersArray, fv) } diff --git a/ext/vulnsrc/rhel/rhel_test.go b/ext/vulnsrc/rhel/rhel_test.go index db762610..e91ec502 100644 --- a/ext/vulnsrc/rhel/rhel_test.go +++ b/ext/vulnsrc/rhel/rhel_test.go @@ -38,41 +38,38 @@ func TestRHELParser(t *testing.T) { assert.Equal(t, database.MediumSeverity, vulnerabilities[0].Severity) assert.Equal(t, `Xerces-C is a validating XML parser written in a portable subset of C++. A flaw was found in the way the Xerces-C XML parser processed certain XML documents. A remote attacker could provide specially crafted XML input that, when parsed by an application using Xerces-C, would cause that application to crash.`, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c", + AffectedVersion: "0:3.1.1-7.el7_1", + FixedInVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-devel", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-devel", + AffectedVersion: "0:3.1.1-7.el7_1", + FixedInVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-doc", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-doc", + AffectedVersion: "0:3.1.1-7.el7_1", + FixedInVersion: "0:3.1.1-7.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } @@ -85,31 +82,29 @@ func TestRHELParser(t *testing.T) { assert.Equal(t, database.CriticalSeverity, vulnerabilities[0].Severity) assert.Equal(t, `Mozilla Firefox is an open source web browser. XULRunner provides the XUL Runtime environment for Mozilla Firefox. Several flaws were found in the processing of malformed web content. A web page containing malicious content could cause Firefox to crash or, potentially, execute arbitrary code with the privileges of the user running Firefox.`, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:6", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "centos:6", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.el6_6", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.el6_6", + AffectedVersion: "0:38.1.0-1.el6_6", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.el7_1", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.el7_1", + AffectedVersion: "0:38.1.0-1.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } } diff --git a/ext/vulnsrc/ubuntu/ubuntu.go b/ext/vulnsrc/ubuntu/ubuntu.go index 28803c76..6af0c14c 100644 --- a/ext/vulnsrc/ubuntu/ubuntu.go +++ b/ext/vulnsrc/ubuntu/ubuntu.go @@ -98,12 +98,25 @@ func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateRespo return resp, err } - // Get the latest revision number we successfully applied in the database. - dbRevisionNumber, err := datastore.GetKeyValue("ubuntuUpdater") + tx, err := datastore.Begin() if err != nil { return resp, err } + // Get the latest revision number we successfully applied in the database. + dbRevisionNumber, ok, err := tx.FindKeyValue("ubuntuUpdater") + if err != nil { + return resp, err + } + + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + dbRevisionNumber = "" + } + // Get the list of vulnerabilities that we have to update. modifiedCVE, err := collectModifiedVulnerabilities(revisionNumber, dbRevisionNumber, u.repositoryLocalPath) if err != nil { @@ -278,11 +291,15 @@ func collectModifiedVulnerabilities(revision int, dbRevision, repositoryLocalPat return modifiedCVE, nil } -func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability, unknownReleases map[string]struct{}, err error) { +func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.VulnerabilityWithAffected, unknownReleases map[string]struct{}, err error) { unknownReleases = make(map[string]struct{}) readingDescription := false scanner := bufio.NewScanner(fileContent) + // only unique major releases will be considered. All sub releases' (e.g. + // precise/esm) features are considered belong to major releases. + uniqueRelease := map[string]struct{}{} + for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -344,7 +361,7 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability // Only consider the package if its status is needed, active, deferred, not-affected or // released. Ignore DNE (package does not exist), needs-triage, ignored, pending. if md["status"] == "needed" || md["status"] == "active" || md["status"] == "deferred" || md["status"] == "released" || md["status"] == "not-affected" { - md["release"] = strings.Split(md["release"], "/")[0] + md["release"] = strings.Split(md["release"], "/")[0] if _, isReleaseIgnored := ubuntuIgnoredReleases[md["release"]]; isReleaseIgnored { continue } @@ -363,8 +380,6 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability } version = md["note"] } - } else if md["status"] == "not-affected" { - version = versionfmt.MinVersion } else { version = versionfmt.MaxVersion } @@ -372,18 +387,30 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability continue } - // Create and add the new package. - featureVersion := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:" + database.UbuntuReleasesMapping[md["release"]], - VersionFormat: dpkg.ParserName, - }, - Name: md["package"], - }, - Version: version, + releaseName := "ubuntu:" + database.UbuntuReleasesMapping[md["release"]] + if _, ok := uniqueRelease[releaseName+"_:_"+md["package"]]; ok { + continue } - vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + + uniqueRelease[releaseName+"_:_"+md["package"]] = struct{}{} + var fixedinVersion string + if version == versionfmt.MaxVersion { + fixedinVersion = "" + } else { + fixedinVersion = version + } + + // Create and add the new package. + featureVersion := database.AffectedFeature{ + Namespace: database.Namespace{ + Name: releaseName, + VersionFormat: dpkg.ParserName, + }, + FeatureName: md["package"], + AffectedVersion: version, + FixedInVersion: fixedinVersion, + } + vulnerability.Affected = append(vulnerability.Affected, featureVersion) } } } diff --git a/ext/vulnsrc/ubuntu/ubuntu_test.go b/ext/vulnsrc/ubuntu/ubuntu_test.go index 5cdbd9a4..a4bd8afd 100644 --- a/ext/vulnsrc/ubuntu/ubuntu_test.go +++ b/ext/vulnsrc/ubuntu/ubuntu_test.go @@ -44,41 +44,37 @@ func TestUbuntuParser(t *testing.T) { _, hasUnkownRelease := unknownReleases["unknown"] assert.True(t, hasUnkownRelease) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:14.04", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack", + Namespace: database.Namespace{ + Name: "ubuntu:14.04", + VersionFormat: dpkg.ParserName, }, - Version: versionfmt.MaxVersion, + FeatureName: "libmspack", + AffectedVersion: versionfmt.MaxVersion, }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:15.04", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack", + Namespace: database.Namespace{ + Name: "ubuntu:15.04", + VersionFormat: dpkg.ParserName, }, - Version: "0.4-3", + FeatureName: "libmspack", + FixedInVersion: "0.4-3", + AffectedVersion: "0.4-3", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:15.10", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack-anotherpkg", + Namespace: database.Namespace{ + Name: "ubuntu:15.10", + VersionFormat: dpkg.ParserName, }, - Version: "0.1", + FeatureName: "libmspack-anotherpkg", + FixedInVersion: "0.1", + AffectedVersion: "0.1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } } diff --git a/notifier.go b/notifier.go index ad3e947c..3b4d5f49 100644 --- a/notifier.go +++ b/notifier.go @@ -24,7 +24,6 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/notification" - "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/stopper" ) @@ -94,14 +93,16 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop go func() { success, interrupted := handleTask(*notification, stopper, config.Attempts) if success { - datastore.SetNotificationNotified(notification.Name) - + err := markNotificationNotified(datastore, notification.Name) + if err != nil { + log.WithError(err).Error("Failed to mark notification notified") + } promNotifierLatencyMilliseconds.Observe(float64(time.Since(notification.Created).Nanoseconds()) / float64(time.Millisecond)) } if interrupted { running = false } - datastore.Unlock(notification.Name, whoAmI) + unlock(datastore, notification.Name, whoAmI) done <- true }() @@ -112,7 +113,10 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop case <-done: break outer case <-time.After(notifierLockRefreshDuration): - datastore.Lock(notification.Name, whoAmI, notifierLockDuration, true) + lock(datastore, notification.Name, whoAmI, notifierLockDuration, true) + case <-stopper.Chan(): + running = false + break } } } @@ -120,13 +124,11 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop log.Info("notifier service stopped") } -func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.VulnerabilityNotification { +func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook { for { - // Find a notification to send. - notification, err := datastore.GetAvailableNotification(renotifyInterval) - if err != nil { - // There is no notification or an error occurred. - if err != commonerr.ErrNotFound { + notification, ok, err := findNewNotification(datastore, renotifyInterval) + if err != nil || !ok { + if !ok { log.WithError(err).Warning("could not get notification to send") } @@ -139,14 +141,14 @@ func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoA } // Lock the notification. - if hasLock, _ := datastore.Lock(notification.Name, whoAmI, notifierLockDuration, false); hasLock { + if hasLock, _ := lock(datastore, notification.Name, whoAmI, notifierLockDuration, false); hasLock { log.WithField(logNotiName, notification.Name).Info("found and locked a notification") return ¬ification } } } -func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAttempts int) (bool, bool) { +func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts int) (bool, bool) { // Send notification. for senderName, sender := range notification.Senders() { var attempts int @@ -167,7 +169,7 @@ func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAt } // Send using the current notifier. - if err := sender.Send(n); err != nil { + if err := sender.Send(n.Name); err != nil { // Send failed; increase attempts/backoff and retry. promNotifierBackendErrorsTotal.WithLabelValues(senderName).Inc() log.WithError(err).WithFields(log.Fields{logSenderName: senderName, logNotiName: n.Name}).Error("could not send notification via notifier") @@ -184,3 +186,66 @@ func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAt log.WithField(logNotiName, n.Name).Info("successfully sent notification") return true, false } + +func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) { + tx, err := datastore.Begin() + if err != nil { + return database.NotificationHook{}, false, err + } + defer tx.Rollback() + return tx.FindNewNotification(time.Now().Add(-renotifyInterval)) +} + +func markNotificationNotified(datastore database.Datastore, name string) error { + tx, err := datastore.Begin() + if err != nil { + log.WithError(err).Error("an error happens when beginning database transaction") + } + defer tx.Rollback() + + if err := tx.MarkNotificationNotified(name); err != nil { + return err + } + return tx.Commit() +} + +// unlock removes a lock with provided name, owner. Internally, it handles +// database transaction and catches error. +func unlock(datastore database.Datastore, name, owner string) { + tx, err := datastore.Begin() + if err != nil { + return + } + + defer tx.Rollback() + + if err := tx.Unlock(name, owner); err != nil { + return + } + if err := tx.Commit(); err != nil { + return + } +} + +func lock(datastore database.Datastore, name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { + // any error will cause the function to catch the error and return false. + tx, err := datastore.Begin() + if err != nil { + return false, time.Time{} + } + + defer tx.Rollback() + + locked, t, err := tx.Lock(name, owner, duration, renew) + if err != nil { + return false, time.Time{} + } + + if locked { + if err := tx.Commit(); err != nil { + return false, time.Time{} + } + } + + return locked, t +} diff --git a/pkg/commonerr/errors.go b/pkg/commonerr/errors.go index 1e690eea..6b268d74 100644 --- a/pkg/commonerr/errors.go +++ b/pkg/commonerr/errors.go @@ -16,7 +16,11 @@ // codebase. package commonerr -import "errors" +import ( + "errors" + "fmt" + "strings" +) var ( // ErrFilesystem occurs when a filesystem interaction fails. @@ -45,3 +49,19 @@ func NewBadRequestError(message string) error { func (e *ErrBadRequest) Error() string { return e.s } + +// CombineErrors merges a slice of errors into one separated by ";". If all +// errors are nil, return nil. +func CombineErrors(errs ...error) error { + errStr := []string{} + for i, err := range errs { + if err != nil { + errStr = append(errStr, fmt.Sprintf("[%d] %s", i, err.Error())) + } + } + + if len(errStr) != 0 { + return errors.New(strings.Join(errStr, ";")) + } + return nil +} diff --git a/pkg/strutil/strutil.go b/pkg/strutil/strutil.go new file mode 100644 index 00000000..a8d04f21 --- /dev/null +++ b/pkg/strutil/strutil.go @@ -0,0 +1,55 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strutil + +// CompareStringLists returns the strings that are present in X but not in Y. +func CompareStringLists(X, Y []string) []string { + m := make(map[string]bool) + + for _, y := range Y { + m[y] = true + } + + diff := []string{} + for _, x := range X { + if m[x] { + continue + } + + diff = append(diff, x) + m[x] = true + } + + return diff +} + +// CompareStringListsInBoth returns the strings that are present in both X and Y. +func CompareStringListsInBoth(X, Y []string) []string { + m := make(map[string]struct{}) + + for _, y := range Y { + m[y] = struct{}{} + } + + diff := []string{} + for _, x := range X { + if _, e := m[x]; e { + diff = append(diff, x) + delete(m, x) + } + } + + return diff +} diff --git a/database/pgsql/migrations/00007_expand_column_width.go b/pkg/strutil/strutil_test.go similarity index 53% rename from database/pgsql/migrations/00007_expand_column_width.go rename to pkg/strutil/strutil_test.go index 8bfdaaab..4cbf1e90 100644 --- a/database/pgsql/migrations/00007_expand_column_width.go +++ b/pkg/strutil/strutil_test.go @@ -12,20 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -package migrations +package strutil -import "github.com/remind101/migrate" +import ( + "testing" -func init() { - RegisterMigration(migrate.Migration{ - ID: 7, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(256);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(256);`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(128);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(128);`, - }), - }) + "github.com/stretchr/testify/assert" +) + +func TestStringComparison(t *testing.T) { + cmp := CompareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) + assert.Len(t, cmp, 1) + assert.NotContains(t, cmp, "a") + assert.Contains(t, cmp, "b") + + cmp = CompareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) + assert.Len(t, cmp, 2) + assert.NotContains(t, cmp, "b") + assert.Contains(t, cmp, "a") + assert.Contains(t, cmp, "c") } diff --git a/updater.go b/updater.go index 2e3aa216..792e068b 100644 --- a/updater.go +++ b/updater.go @@ -15,6 +15,7 @@ package clair import ( + "fmt" "math/rand" "strconv" "sync" @@ -53,6 +54,9 @@ var ( Name: "clair_updater_notes_total", Help: "Number of notes that the vulnerability fetchers generated.", }) + + // EnabledUpdaters contains all updaters to be used for update. + EnabledUpdaters []string ) func init() { @@ -63,7 +67,13 @@ func init() { // UpdaterConfig is the configuration for the Updater service. type UpdaterConfig struct { - Interval time.Duration + EnabledUpdaters []string + Interval time.Duration +} + +type vulnerabilityChange struct { + old *database.VulnerabilityWithAffected + new *database.VulnerabilityWithAffected } // RunUpdater begins a process that updates the vulnerability database at @@ -72,7 +82,7 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper defer st.End() // Do not run the updater if there is no config or if the interval is 0. - if config == nil || config.Interval == 0 { + if config == nil || config.Interval == 0 || len(config.EnabledUpdaters) == 0 { log.Info("updater service is disabled.") return } @@ -86,11 +96,11 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper // Determine if this is the first update and define the next update time. // The next update time is (last update time + interval) or now if this is the first update. nextUpdate := time.Now().UTC() - lastUpdate, firstUpdate, err := getLastUpdate(datastore) + lastUpdate, firstUpdate, err := GetLastUpdateTime(datastore) if err != nil { - log.WithError(err).Error("an error occured while getting the last update time") + log.WithError(err).Error("an error occurred while getting the last update time") nextUpdate = nextUpdate.Add(config.Interval) - } else if firstUpdate == false { + } else if !firstUpdate { nextUpdate = lastUpdate.Add(config.Interval) } @@ -98,7 +108,7 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper if nextUpdate.Before(time.Now().UTC()) { // Attempt to get a lock on the the update. log.Debug("attempting to obtain update lock") - hasLock, hasLockUntil := datastore.Lock(updaterLockName, whoAmI, updaterLockDuration, false) + hasLock, hasLockUntil := lock(datastore, updaterLockName, whoAmI, updaterLockDuration, false) if hasLock { // Launch update in a new go routine. doneC := make(chan bool, 1) @@ -113,14 +123,14 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper done = true case <-time.After(updaterLockRefreshDuration): // Refresh the lock until the update is done. - datastore.Lock(updaterLockName, whoAmI, updaterLockDuration, true) + lock(datastore, updaterLockName, whoAmI, updaterLockDuration, true) case <-st.Chan(): stop = true } } - // Unlock the update. - datastore.Unlock(updaterLockName, whoAmI) + // Unlock the updater. + unlock(datastore, updaterLockName, whoAmI) if stop { break @@ -132,10 +142,9 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper break } continue - } else { - lockOwner, lockExpiration, err := datastore.FindLock(updaterLockName) - if err != nil { + lockOwner, lockExpiration, ok, err := findLock(datastore, updaterLockName) + if !ok || err != nil { log.Debug("update lock is already taken") nextUpdate = hasLockUntil } else { @@ -174,40 +183,74 @@ func sleepUpdater(approxWakeup time.Time, st *stopper.Stopper) (stopped bool) { return false } -// update fetches all the vulnerabilities from the registered fetchers, upserts -// them into the database and then sends notifications. +// update fetches all the vulnerabilities from the registered fetchers, updates +// vulnerabilities, and updater flags, and logs notes from updaters. func update(datastore database.Datastore, firstUpdate bool) { defer setUpdaterDuration(time.Now()) log.Info("updating vulnerabilities") // Fetch updates. - status, vulnerabilities, flags, notes := fetch(datastore) + success, vulnerabilities, flags, notes := fetch(datastore) - // Insert vulnerabilities. - log.WithField("count", len(vulnerabilities)).Debug("inserting vulnerabilities for update") - err := datastore.InsertVulnerabilities(vulnerabilities, !firstUpdate) - if err != nil { - promUpdaterErrorsTotal.Inc() - log.WithError(err).Error("an error occured when inserting vulnerabilities for update") + // do vulnerability namespacing again to merge potentially duplicated + // vulnerabilities from each updater. + vulnerabilities = doVulnerabilitiesNamespacing(vulnerabilities) + + // deduplicate fetched namespaces and store them into database. + nsMap := map[database.Namespace]struct{}{} + for _, vuln := range vulnerabilities { + nsMap[vuln.Namespace] = struct{}{} + } + + namespaces := make([]database.Namespace, 0, len(nsMap)) + for ns := range nsMap { + namespaces = append(namespaces, ns) + } + + if err := persistNamespaces(datastore, namespaces); err != nil { + log.WithError(err).Error("Unable to insert namespaces") return } - vulnerabilities = nil - // Update flags. - for flagName, flagValue := range flags { - datastore.InsertKeyValue(flagName, flagValue) + changes, err := updateVulnerabilities(datastore, vulnerabilities) + + defer func() { + if err != nil { + promUpdaterErrorsTotal.Inc() + } + }() + + if err != nil { + log.WithError(err).Error("Unable to update vulnerabilities") + return + } + + if !firstUpdate { + err = createVulnerabilityNotifications(datastore, changes) + if err != nil { + log.WithError(err).Error("Unable to create notifications") + return + } + } + + err = updateUpdaterFlags(datastore, flags) + if err != nil { + log.WithError(err).Error("Unable to update updater flags") + return } - // Log notes. for _, note := range notes { log.WithField("note", note).Warning("fetcher note") } promUpdaterNotesTotal.Set(float64(len(notes))) - // Update last successful update if every fetchers worked properly. - if status { - datastore.InsertKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10)) + if success { + err = setLastUpdateTime(datastore) + if err != nil { + log.WithError(err).Error("Unable to set last update time") + return + } } log.Info("update finished") @@ -218,8 +261,8 @@ func setUpdaterDuration(start time.Time) { } // fetch get data from the registered fetchers, in parallel. -func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[string]string, []string) { - var vulnerabilities []database.Vulnerability +func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) { + var vulnerabilities []database.VulnerabilityWithAffected var notes []string status := true flags := make(map[string]string) @@ -227,12 +270,17 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st // Fetch updates in parallel. log.Info("fetching vulnerability updates") var responseC = make(chan *vulnsrc.UpdateResponse, 0) + numUpdaters := 0 for n, u := range vulnsrc.Updaters() { + if !updaterEnabled(n) { + continue + } + numUpdaters++ go func(name string, u vulnsrc.Updater) { response, err := u.Update(datastore) if err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("updater name", name).Error("an error occured when fetching update") + log.WithError(err).WithField("updater name", name).Error("an error occurred when fetching update") status = false responseC <- nil return @@ -244,7 +292,7 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st } // Collect results of updates. - for i := 0; i < len(vulnsrc.Updaters()); i++ { + for i := 0; i < numUpdaters; i++ { resp := <-responseC if resp != nil { vulnerabilities = append(vulnerabilities, doVulnerabilitiesNamespacing(resp.Vulnerabilities)...) @@ -259,9 +307,10 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st return status, addMetadata(datastore, vulnerabilities), flags, notes } -// Add metadata to the specified vulnerabilities using the registered MetadataFetchers, in parallel. -func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulnerability) []database.Vulnerability { - if len(vulnmdsrc.Appenders()) == 0 { +// Add metadata to the specified vulnerabilities using the registered +// MetadataFetchers, in parallel. +func addMetadata(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { + if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 { return vulnerabilities } @@ -272,7 +321,7 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner lockableVulnerabilities := make([]*lockableVulnerability, 0, len(vulnerabilities)) for i := 0; i < len(vulnerabilities); i++ { lockableVulnerabilities = append(lockableVulnerabilities, &lockableVulnerability{ - Vulnerability: &vulnerabilities[i], + VulnerabilityWithAffected: &vulnerabilities[i], }) } @@ -286,7 +335,7 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner // Build up a metadata cache. if err := appender.BuildCache(datastore); err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("appender name", name).Error("an error occured when loading metadata fetcher") + log.WithError(err).WithField("appender name", name).Error("an error occurred when loading metadata fetcher") return } @@ -305,13 +354,21 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner return vulnerabilities } -func getLastUpdate(datastore database.Datastore) (time.Time, bool, error) { - lastUpdateTSS, err := datastore.GetKeyValue(updaterLastFlagName) +// GetLastUpdateTime retrieves the latest successful time of update and whether +// or not it's the first update. +func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) { + tx, err := datastore.Begin() + if err != nil { + return time.Time{}, false, err + } + defer tx.Rollback() + + lastUpdateTSS, ok, err := tx.FindKeyValue(updaterLastFlagName) if err != nil { return time.Time{}, false, err } - if lastUpdateTSS == "" { + if !ok { // This is the first update. return time.Time{}, true, nil } @@ -325,7 +382,7 @@ func getLastUpdate(datastore database.Datastore) (time.Time, bool, error) { } type lockableVulnerability struct { - *database.Vulnerability + *database.VulnerabilityWithAffected sync.Mutex } @@ -349,39 +406,293 @@ func (lv *lockableVulnerability) appendFunc(metadataKey string, metadata interfa // doVulnerabilitiesNamespacing takes Vulnerabilities that don't have a // Namespace and split them into multiple vulnerabilities that have a Namespace -// and only contains the FixedIn FeatureVersions corresponding to their +// and only contains the Affected Features corresponding to their // Namespace. // // It helps simplifying the fetchers that share the same metadata about a // Vulnerability regardless of their actual namespace (ie. same vulnerability // information for every version of a distro). -func doVulnerabilitiesNamespacing(vulnerabilities []database.Vulnerability) []database.Vulnerability { - vulnerabilitiesMap := make(map[string]*database.Vulnerability) +// +// It also validates the vulnerabilities fetched from updaters. If any +// vulnerability is mal-formated, the updater process will continue but will log +// warning. +func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { + vulnerabilitiesMap := make(map[string]*database.VulnerabilityWithAffected) for _, v := range vulnerabilities { - featureVersions := v.FixedIn - v.FixedIn = []database.FeatureVersion{} + namespacedFeatures := v.Affected + v.Affected = []database.AffectedFeature{} - for _, fv := range featureVersions { - index := fv.Feature.Namespace.Name + ":" + v.Name + for _, fv := range namespacedFeatures { + // validate vulnerabilities, throw out the invalid vulnerabilities + if fv.AffectedVersion == "" || fv.FeatureName == "" || fv.Namespace.Name == "" || fv.Namespace.VersionFormat == "" { + log.WithFields(log.Fields{ + "Name": fv.FeatureName, + "Affected Version": fv.AffectedVersion, + "Namespace": fv.Namespace.Name + ":" + fv.Namespace.VersionFormat, + }).Warn("Mal-formated affected feature (skipped)") + continue + } + index := fv.Namespace.Name + ":" + v.Name if vulnerability, ok := vulnerabilitiesMap[index]; !ok { newVulnerability := v - newVulnerability.Namespace = fv.Feature.Namespace - newVulnerability.FixedIn = []database.FeatureVersion{fv} + newVulnerability.Namespace = fv.Namespace + newVulnerability.Affected = []database.AffectedFeature{fv} vulnerabilitiesMap[index] = &newVulnerability } else { - vulnerability.FixedIn = append(vulnerability.FixedIn, fv) + vulnerability.Affected = append(vulnerability.Affected, fv) } } } // Convert map into a slice. - var response []database.Vulnerability - for _, vulnerability := range vulnerabilitiesMap { - response = append(response, *vulnerability) + var response []database.VulnerabilityWithAffected + for _, v := range vulnerabilitiesMap { + // throw out invalid vulnerabilities. + if v.Name == "" || !v.Severity.Valid() || v.Namespace.Name == "" || v.Namespace.VersionFormat == "" { + log.WithFields(log.Fields{ + "Name": v.Name, + "Severity": v.Severity, + "Namespace": v.Namespace.Name + ":" + v.Namespace.VersionFormat, + }).Warning("Vulnerability is mal-formatted") + continue + } + response = append(response, *v) } return response } + +func findLock(datastore database.Datastore, updaterLockName string) (string, time.Time, bool, error) { + tx, err := datastore.Begin() + if err != nil { + log.WithError(err).Error() + } + defer tx.Rollback() + return tx.FindLock(updaterLockName) +} + +// updateUpdaterFlags updates the flags specified by updaters, every transaction +// is independent of each other. +func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error { + for key, value := range flags { + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + err = tx.UpdateKeyValue(key, value) + if err != nil { + return err + } + if err = tx.Commit(); err != nil { + return err + } + } + return nil +} + +// setLastUpdateTime records the last successful date time in database. +func setLastUpdateTime(datastore database.Datastore) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + err = tx.UpdateKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10)) + if err != nil { + return err + } + return tx.Commit() +} + +// isVulnerabilityChange compares two vulnerabilities by their severity and +// affected features, and return true if they are different. +func isVulnerabilityChanged(a *database.VulnerabilityWithAffected, b *database.VulnerabilityWithAffected) bool { + if a == b { + return false + } else if a != nil && b != nil && a.Severity == b.Severity && len(a.Affected) == len(b.Affected) { + checked := map[string]bool{} + for _, affected := range a.Affected { + checked[affected.Namespace.Name+":"+affected.FeatureName] = false + } + + for _, affected := range b.Affected { + key := affected.Namespace.Name + ":" + affected.FeatureName + if visited, ok := checked[key]; !ok || visited { + return true + } + checked[key] = true + } + return false + } + return true +} + +// findVulnerabilityChanges finds vulnerability changes from old +// vulnerabilities to new vulnerabilities. +// old and new vulnerabilities should be unique. +func findVulnerabilityChanges(old []database.VulnerabilityWithAffected, new []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { + changes := map[database.VulnerabilityID]vulnerabilityChange{} + for i, vuln := range old { + key := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + + if _, ok := changes[key]; ok { + return nil, fmt.Errorf("duplicated old vulnerability") + } + changes[key] = vulnerabilityChange{old: &old[i]} + } + + for i, vuln := range new { + key := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + + if change, ok := changes[key]; ok { + if isVulnerabilityChanged(change.old, &vuln) { + change.new = &new[i] + changes[key] = change + } else { + delete(changes, key) + } + } else { + changes[key] = vulnerabilityChange{new: &new[i]} + } + } + + vulnChange := make([]vulnerabilityChange, 0, len(changes)) + for _, change := range changes { + vulnChange = append(vulnChange, change) + } + return vulnChange, nil +} + +// createVulnerabilityNotifications makes notifications out of vulnerability +// changes and insert them into database. +func createVulnerabilityNotifications(datastore database.Datastore, changes []vulnerabilityChange) error { + log.WithField("count", len(changes)).Debug("creating vulnerability notifications") + if len(changes) == 0 { + return nil + } + + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + notifications := make([]database.VulnerabilityNotification, 0, len(changes)) + for _, change := range changes { + var oldVuln, newVuln *database.Vulnerability + if change.old != nil { + oldVuln = &change.old.Vulnerability + } + + if change.new != nil { + newVuln = &change.new.Vulnerability + } + + notifications = append(notifications, database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: uuid.New(), + Created: time.Now(), + }, + Old: oldVuln, + New: newVuln, + }) + } + + if err := tx.InsertVulnerabilityNotifications(notifications); err != nil { + return err + } + + return tx.Commit() +} + +// updateVulnerabilities upserts unique vulnerabilities into the database and +// computes vulnerability changes. +func updateVulnerabilities(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { + log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities") + if len(vulnerabilities) == 0 { + return nil, nil + } + + ids := make([]database.VulnerabilityID, 0, len(vulnerabilities)) + for _, vuln := range vulnerabilities { + ids = append(ids, database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + }) + } + + tx, err := datastore.Begin() + if err != nil { + return nil, err + } + + defer tx.Rollback() + oldVulnNullable, err := tx.FindVulnerabilities(ids) + if err != nil { + return nil, err + } + + oldVuln := []database.VulnerabilityWithAffected{} + for _, vuln := range oldVulnNullable { + if vuln.Valid { + oldVuln = append(oldVuln, vuln.VulnerabilityWithAffected) + } + } + + changes, err := findVulnerabilityChanges(oldVuln, vulnerabilities) + if err != nil { + return nil, err + } + + toRemove := []database.VulnerabilityID{} + toAdd := []database.VulnerabilityWithAffected{} + for _, change := range changes { + if change.old != nil { + toRemove = append(toRemove, database.VulnerabilityID{ + Name: change.old.Name, + Namespace: change.old.Namespace.Name, + }) + } + + if change.new != nil { + toAdd = append(toAdd, *change.new) + } + } + + log.WithField("count", len(toRemove)).Debug("marking vulnerabilities as outdated") + if err := tx.DeleteVulnerabilities(toRemove); err != nil { + return nil, err + } + + log.WithField("count", len(toAdd)).Debug("inserting new vulnerabilities") + if err := tx.InsertVulnerabilities(toAdd); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return changes, nil +} + +func updaterEnabled(updaterName string) bool { + for _, u := range EnabledUpdaters { + if u == updaterName { + return true + } + } + return false +} diff --git a/updater_test.go b/updater_test.go index 380ff277..bb2a8e60 100644 --- a/updater_test.go +++ b/updater_test.go @@ -15,6 +15,7 @@ package clair import ( + "errors" "fmt" "testing" @@ -23,49 +24,301 @@ import ( "github.com/coreos/clair/database" ) +type mockUpdaterDatastore struct { + database.MockDatastore + + namespaces map[string]database.Namespace + vulnerabilities map[database.VulnerabilityID]database.VulnerabilityWithAffected + vulnNotification map[string]database.VulnerabilityNotification + keyValues map[string]string +} + +type mockUpdaterSession struct { + database.MockSession + + store *mockUpdaterDatastore + copy mockUpdaterDatastore + terminated bool +} + +func copyUpdaterDatastore(md *mockUpdaterDatastore) mockUpdaterDatastore { + namespaces := map[string]database.Namespace{} + for k, n := range md.namespaces { + namespaces[k] = n + } + + vulnerabilities := map[database.VulnerabilityID]database.VulnerabilityWithAffected{} + for key, v := range md.vulnerabilities { + newV := v + affected := []database.AffectedFeature{} + for _, f := range v.Affected { + affected = append(affected, f) + } + newV.Affected = affected + vulnerabilities[key] = newV + } + + vulnNoti := map[string]database.VulnerabilityNotification{} + for key, v := range md.vulnNotification { + vulnNoti[key] = v + } + + kv := map[string]string{} + for key, value := range md.keyValues { + kv[key] = value + } + + return mockUpdaterDatastore{ + namespaces: namespaces, + vulnerabilities: vulnerabilities, + vulnNotification: vulnNoti, + keyValues: kv, + } +} + +func newmockUpdaterDatastore() *mockUpdaterDatastore { + errSessionDone := errors.New("Session Done") + md := &mockUpdaterDatastore{ + namespaces: make(map[string]database.Namespace), + vulnerabilities: make(map[database.VulnerabilityID]database.VulnerabilityWithAffected), + vulnNotification: make(map[string]database.VulnerabilityNotification), + keyValues: make(map[string]string), + } + + md.FctBegin = func() (database.Session, error) { + session := &mockUpdaterSession{ + store: md, + copy: copyUpdaterDatastore(md), + terminated: false, + } + + session.FctCommit = func() error { + if session.terminated { + return errSessionDone + } + session.store.namespaces = session.copy.namespaces + session.store.vulnerabilities = session.copy.vulnerabilities + session.store.vulnNotification = session.copy.vulnNotification + session.store.keyValues = session.copy.keyValues + session.terminated = true + return nil + } + + session.FctRollback = func() error { + if session.terminated { + return errSessionDone + } + session.terminated = true + session.copy = mockUpdaterDatastore{} + return nil + } + + session.FctPersistNamespaces = func(ns []database.Namespace) error { + if session.terminated { + return errSessionDone + } + for _, n := range ns { + _, ok := session.copy.namespaces[n.Name] + if !ok { + session.copy.namespaces[n.Name] = n + } + } + return nil + } + + session.FctFindVulnerabilities = func(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + r := []database.NullableVulnerability{} + for _, id := range ids { + vuln, ok := session.copy.vulnerabilities[id] + r = append(r, database.NullableVulnerability{ + VulnerabilityWithAffected: vuln, + Valid: ok, + }) + } + return r, nil + } + + session.FctDeleteVulnerabilities = func(ids []database.VulnerabilityID) error { + for _, id := range ids { + delete(session.copy.vulnerabilities, id) + } + return nil + } + + session.FctInsertVulnerabilities = func(vulnerabilities []database.VulnerabilityWithAffected) error { + for _, vuln := range vulnerabilities { + id := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + if _, ok := session.copy.vulnerabilities[id]; ok { + return errors.New("Vulnerability already exists") + } + session.copy.vulnerabilities[id] = vuln + } + return nil + } + + session.FctUpdateKeyValue = func(key, value string) error { + session.copy.keyValues[key] = value + return nil + } + + session.FctFindKeyValue = func(key string) (string, bool, error) { + s, b := session.copy.keyValues[key] + return s, b, nil + } + + session.FctInsertVulnerabilityNotifications = func(notifications []database.VulnerabilityNotification) error { + for _, noti := range notifications { + session.copy.vulnNotification[noti.Name] = noti + } + return nil + } + + return session, nil + } + return md +} + func TestDoVulnerabilitiesNamespacing(t *testing.T) { - fv1 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace1"}, - Name: "Feature1", + fv1 := database.AffectedFeature{ + Namespace: database.Namespace{Name: "Namespace1"}, + FeatureName: "Feature1", + FixedInVersion: "0.1", + AffectedVersion: "0.1", + } + + fv2 := database.AffectedFeature{ + Namespace: database.Namespace{Name: "Namespace2"}, + FeatureName: "Feature1", + FixedInVersion: "0.2", + AffectedVersion: "0.2", + } + + fv3 := database.AffectedFeature{ + + Namespace: database.Namespace{Name: "Namespace2"}, + FeatureName: "Feature2", + FixedInVersion: "0.3", + AffectedVersion: "0.3", + } + + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "DoVulnerabilityNamespacing", }, - Version: "0.1", + Affected: []database.AffectedFeature{fv1, fv2, fv3}, } - fv2 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace2"}, - Name: "Feature1", - }, - Version: "0.2", - } - - fv3 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace2"}, - Name: "Feature2", - }, - Version: "0.3", - } - - vulnerability := database.Vulnerability{ - Name: "DoVulnerabilityNamespacing", - FixedIn: []database.FeatureVersion{fv1, fv2, fv3}, - } - - vulnerabilities := doVulnerabilitiesNamespacing([]database.Vulnerability{vulnerability}) + vulnerabilities := doVulnerabilitiesNamespacing([]database.VulnerabilityWithAffected{vulnerability}) for _, vulnerability := range vulnerabilities { switch vulnerability.Namespace.Name { - case fv1.Feature.Namespace.Name: - assert.Len(t, vulnerability.FixedIn, 1) - assert.Contains(t, vulnerability.FixedIn, fv1) - case fv2.Feature.Namespace.Name: - assert.Len(t, vulnerability.FixedIn, 2) - assert.Contains(t, vulnerability.FixedIn, fv2) - assert.Contains(t, vulnerability.FixedIn, fv3) + case fv1.Namespace.Name: + assert.Len(t, vulnerability.Affected, 1) + assert.Contains(t, vulnerability.Affected, fv1) + case fv2.Namespace.Name: + assert.Len(t, vulnerability.Affected, 2) + assert.Contains(t, vulnerability.Affected, fv2) + assert.Contains(t, vulnerability.Affected, fv3) default: t.Errorf("Should not have a Vulnerability with '%s' as its Namespace.", vulnerability.Namespace.Name) fmt.Printf("%#v\n", vulnerability) } } } + +func TestCreatVulnerabilityNotification(t *testing.T) { + vf1 := "VersionFormat1" + ns1 := database.Namespace{ + Name: "namespace 1", + VersionFormat: vf1, + } + af1 := database.AffectedFeature{ + Namespace: ns1, + FeatureName: "feature 1", + } + + v1 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.UnknownSeverity, + }, + } + + // severity change + v2 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.LowSeverity, + }, + } + + // affected versions change + v3 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{af1}, + } + + datastore := newmockUpdaterDatastore() + change, err := updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{}) + assert.Nil(t, err) + assert.Len(t, change, 0) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assert.Nil(t, change[0].old) + assertVulnerability(t, *change[0].new, v1) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + assert.Nil(t, err) + assert.Len(t, change, 0) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v2}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assertVulnerability(t, *change[0].new, v2) + assertVulnerability(t, *change[0].old, v1) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v3}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assertVulnerability(t, *change[0].new, v3) + assertVulnerability(t, *change[0].old, v2) + + err = createVulnerabilityNotifications(datastore, change) + assert.Nil(t, err) + assert.Len(t, datastore.vulnNotification, 1) + for _, noti := range datastore.vulnNotification { + assert.Equal(t, *noti.New, v3.Vulnerability) + assert.Equal(t, *noti.Old, v2.Vulnerability) + } +} + +func assertVulnerability(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + expectedAF := expected.Affected + actualAF := actual.Affected + expected.Affected, actual.Affected = nil, nil + + assert.Equal(t, expected, actual) + assert.Len(t, actualAF, len(expectedAF)) + + mapAF := map[database.AffectedFeature]bool{} + for _, af := range expectedAF { + mapAF[af] = false + } + + for _, af := range actualAF { + if visited, ok := mapAF[af]; !ok || visited { + return false + } + } + return true +} diff --git a/worker.go b/worker.go index d407253f..a9c82762 100644 --- a/worker.go +++ b/worker.go @@ -15,7 +15,9 @@ package clair import ( + "errors" "regexp" + "sync" log "github.com/sirupsen/logrus" @@ -24,13 +26,10 @@ import ( "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/imagefmt" "github.com/coreos/clair/pkg/commonerr" - "github.com/coreos/clair/pkg/tarutil" + "github.com/coreos/clair/pkg/strutil" ) const ( - // Version (integer) represents the worker version. - // Increased each time the engine changes. - Version = 3 logLayerName = "layer" ) @@ -44,177 +43,525 @@ var ( ErrParentUnknown = commonerr.NewBadRequestError("worker: parent layer is unknown, it must be processed first") urlParametersRegexp = regexp.MustCompile(`(\?|\&)([^=]+)\=([^ &]+)`) + + // Processors contain the names of namespace detectors and feature listers + // enabled in this instance of Clair. + // + // Processors are initialized during booting and configured in the + // configuration file. + Processors database.Processors ) +type WorkerConfig struct { + EnabledDetectors []string `yaml:"namespace_detectors"` + EnabledListers []string `yaml:"feature_listers"` +} + +// LayerRequest represents all information necessary to download and process a +// layer. +type LayerRequest struct { + Hash string + Path string + Headers map[string]string +} + +// partialLayer stores layer's content detected by `processedBy` processors. +type partialLayer struct { + hash string + processedBy database.Processors + namespaces []database.Namespace + features []database.Feature + + err error +} + +// processRequest stores parameters used for processing layers. +type processRequest struct { + request LayerRequest + // notProcessedBy represents a set of processors used to process the + // request. + notProcessedBy database.Processors +} + // cleanURL removes all parameters from an URL. func cleanURL(str string) string { return urlParametersRegexp.ReplaceAllString(str, "") } -// ProcessLayer detects the Namespace of a layer, the features it adds/removes, -// and then stores everything in the database. -// -// TODO(Quentin-M): We could have a goroutine that looks for layers that have -// been analyzed with an older engine version and that processes them. -func ProcessLayer(datastore database.Datastore, imageFormat, name, parentName, path string, headers map[string]string) error { - // Verify parameters. - if name == "" { - return commonerr.NewBadRequestError("could not process a layer which does not have a name") +// processLayers in parallel processes a set of requests for unique set of layers +// and returns sets of unique namespaces, features and layers to be inserted +// into the database. +func processRequests(imageFormat string, toDetect []processRequest) ([]database.Namespace, []database.Feature, map[string]partialLayer, error) { + wg := &sync.WaitGroup{} + wg.Add(len(toDetect)) + results := make([]partialLayer, len(toDetect)) + for i := range toDetect { + go func(req *processRequest, res *partialLayer) { + res.hash = req.request.Hash + res.processedBy = req.notProcessedBy + res.namespaces, res.features, res.err = detectContent(imageFormat, req.request.Hash, req.request.Path, req.request.Headers, req.notProcessedBy) + wg.Done() + }(&toDetect[i], &results[i]) + } + wg.Wait() + distinctNS := map[database.Namespace]struct{}{} + distinctF := map[database.Feature]struct{}{} + + errs := []error{} + for _, r := range results { + errs = append(errs, r.err) } - if path == "" { - return commonerr.NewBadRequestError("could not process a layer which does not have a path") + if err := commonerr.CombineErrors(errs...); err != nil { + return nil, nil, nil, err + } + + updates := map[string]partialLayer{} + for _, r := range results { + for _, ns := range r.namespaces { + distinctNS[ns] = struct{}{} + } + + for _, f := range r.features { + distinctF[f] = struct{}{} + } + + if _, ok := updates[r.hash]; !ok { + updates[r.hash] = r + } else { + return nil, nil, nil, errors.New("Duplicated updates is not allowed") + } + } + + namespaces := make([]database.Namespace, 0, len(distinctNS)) + features := make([]database.Feature, 0, len(distinctF)) + + for ns := range distinctNS { + namespaces = append(namespaces, ns) + } + + for f := range distinctF { + features = append(features, f) + } + return namespaces, features, updates, nil +} + +func getLayer(datastore database.Datastore, req LayerRequest) (layer database.LayerWithContent, preq *processRequest, err error) { + var ok bool + tx, err := datastore.Begin() + if err != nil { + return + } + defer tx.Rollback() + + layer, ok, err = tx.FindLayerWithContent(req.Hash) + if err != nil { + return + } + + if !ok { + l := database.Layer{Hash: req.Hash} + err = tx.PersistLayer(l) + if err != nil { + return + } + + if err = tx.Commit(); err != nil { + return + } + + layer = database.LayerWithContent{Layer: l} + preq = &processRequest{ + request: req, + notProcessedBy: Processors, + } + } else { + notProcessed := getNotProcessedBy(layer.ProcessedBy) + if !(len(notProcessed.Detectors) == 0 && len(notProcessed.Listers) == 0 && ok) { + preq = &processRequest{ + request: req, + notProcessedBy: notProcessed, + } + } + } + return +} + +// processLayers processes a set of post layer requests, stores layers and +// returns an ordered list of processed layers with detected features and +// namespaces. +func processLayers(datastore database.Datastore, imageFormat string, requests []LayerRequest) ([]database.LayerWithContent, error) { + toDetect := []processRequest{} + layers := map[string]database.LayerWithContent{} + for _, req := range requests { + if _, ok := layers[req.Hash]; ok { + continue + } + layer, preq, err := getLayer(datastore, req) + if err != nil { + return nil, err + } + layers[req.Hash] = layer + if preq != nil { + toDetect = append(toDetect, *preq) + } + } + + namespaces, features, partialRes, err := processRequests(imageFormat, toDetect) + if err != nil { + return nil, err + } + + // Store partial results. + if err := persistNamespaces(datastore, namespaces); err != nil { + return nil, err + } + + if err := persistFeatures(datastore, features); err != nil { + return nil, err + } + + for _, res := range partialRes { + if err := persistPartialLayer(datastore, res); err != nil { + return nil, err + } + } + + // NOTE(Sida): The full layers are computed using partially + // processed layers in current database session. If any other instances of + // Clair are changing some layers in this set of layers, it might generate + // different results especially when the other Clair is with different + // processors. + completeLayers := []database.LayerWithContent{} + for _, req := range requests { + if partialLayer, ok := partialRes[req.Hash]; ok { + completeLayers = append(completeLayers, combineLayers(layers[req.Hash], partialLayer)) + } else { + completeLayers = append(completeLayers, layers[req.Hash]) + } + } + + return completeLayers, nil +} + +func persistPartialLayer(datastore database.Datastore, layer partialLayer) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err := tx.PersistLayerContent(layer.hash, layer.namespaces, layer.features, layer.processedBy); err != nil { + return err + } + return tx.Commit() +} + +func persistFeatures(datastore database.Datastore, features []database.Feature) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err := tx.PersistFeatures(features); err != nil { + return err + } + return tx.Commit() +} + +func persistNamespaces(datastore database.Datastore, namespaces []database.Namespace) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err := tx.PersistNamespaces(namespaces); err != nil { + return err + } + + return tx.Commit() +} + +// combineLayers merges `layer` and `partial` without duplicated content. +func combineLayers(layer database.LayerWithContent, partial partialLayer) database.LayerWithContent { + mapF := map[database.Feature]struct{}{} + mapNS := map[database.Namespace]struct{}{} + for _, f := range layer.Features { + mapF[f] = struct{}{} + } + for _, ns := range layer.Namespaces { + mapNS[ns] = struct{}{} + } + for _, f := range partial.features { + mapF[f] = struct{}{} + } + for _, ns := range partial.namespaces { + mapNS[ns] = struct{}{} + } + features := make([]database.Feature, 0, len(mapF)) + namespaces := make([]database.Namespace, 0, len(mapNS)) + for f := range mapF { + features = append(features, f) + } + for ns := range mapNS { + namespaces = append(namespaces, ns) + } + + layer.ProcessedBy.Detectors = append(layer.ProcessedBy.Detectors, strutil.CompareStringLists(partial.processedBy.Detectors, layer.ProcessedBy.Detectors)...) + layer.ProcessedBy.Listers = append(layer.ProcessedBy.Listers, strutil.CompareStringLists(partial.processedBy.Listers, layer.ProcessedBy.Listers)...) + return database.LayerWithContent{ + Layer: database.Layer{ + Hash: layer.Hash, + }, + ProcessedBy: layer.ProcessedBy, + Features: features, + Namespaces: namespaces, + } +} + +func isAncestryProcessed(datastore database.Datastore, name string) (bool, error) { + tx, err := datastore.Begin() + if err != nil { + return false, err + } + defer tx.Rollback() + _, processed, ok, err := tx.FindAncestry(name) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + + notProcessed := getNotProcessedBy(processed) + return len(notProcessed.Detectors) == 0 && len(notProcessed.Listers) == 0, nil +} + +// ProcessAncestry downloads and scans an ancestry if it's not scanned by all +// enabled processors in this instance of Clair. +func ProcessAncestry(datastore database.Datastore, imageFormat, name string, layerRequest []LayerRequest) error { + var err error + if name == "" { + return commonerr.NewBadRequestError("could not process a layer which does not have a name") } if imageFormat == "" { return commonerr.NewBadRequestError("could not process a layer which does not have a format") } - log.WithFields(log.Fields{logLayerName: name, "path": cleanURL(path), "engine version": Version, "parent layer": parentName, "format": imageFormat}).Debug("processing layer") - - // Check to see if the layer is already in the database. - layer, err := datastore.FindLayer(name, false, false) - if err != nil && err != commonerr.ErrNotFound { + if ok, err := isAncestryProcessed(datastore, name); ok && err == nil { + log.WithField("ancestry", name).Debug("Ancestry is processed") + return nil + } else if err != nil { return err } - if err == commonerr.ErrNotFound { - // New layer case. - layer = database.Layer{Name: name, EngineVersion: Version} - - // Retrieve the parent if it has one. - // We need to get it with its Features in order to diff them. - if parentName != "" { - parent, err := datastore.FindLayer(parentName, true, false) - if err != nil && err != commonerr.ErrNotFound { - return err - } - if err == commonerr.ErrNotFound { - log.WithFields(log.Fields{logLayerName: name, "parent layer": parentName}).Warning("the parent layer is unknown. it must be processed first") - return ErrParentUnknown - } - layer.Parent = &parent - } - } else { - // The layer is already in the database, check if we need to update it. - if layer.EngineVersion >= Version { - log.WithFields(log.Fields{logLayerName: name, "past engine version": layer.EngineVersion, "current engine version": Version}).Debug("layer content has already been processed in the past with older engine. skipping analysis") - return nil - } - log.WithFields(log.Fields{logLayerName: name, "past engine version": layer.EngineVersion, "current engine version": Version}).Debug("layer content has already been processed in the past with older engine. analyzing again") - } - - // Analyze the content. - layer.Namespaces, layer.Features, err = detectContent(imageFormat, name, path, headers, layer.Parent) + layers, err := processLayers(datastore, imageFormat, layerRequest) if err != nil { return err } - return datastore.InsertLayer(layer) + if !validateProcessors(layers) { + // This error might be triggered because of multiple workers are + // processing the same instance with different processors. + return errors.New("ancestry layers are scanned with different listers and detectors") + } + + return processAncestry(datastore, name, layers) } -// detectContent downloads a layer's archive and extracts its Namespace and -// Features. -func detectContent(imageFormat, name, path string, headers map[string]string, parent *database.Layer) (namespaces []database.Namespace, featureVersions []database.FeatureVersion, err error) { - totalRequiredFiles := append(featurefmt.RequiredFilenames(), featurens.RequiredFilenames()...) +func processAncestry(datastore database.Datastore, name string, layers []database.LayerWithContent) error { + ancestryFeatures, err := computeAncestryFeatures(layers) + if err != nil { + return err + } + + ancestryLayers := make([]database.Layer, 0, len(layers)) + for _, layer := range layers { + ancestryLayers = append(ancestryLayers, layer.Layer) + } + + log.WithFields(log.Fields{ + "ancestry": name, + "number of features": len(ancestryFeatures), + "processed by": Processors, + "number of layers": len(ancestryLayers), + }).Debug("compute ancestry features") + + if err := persistNamespacedFeatures(datastore, ancestryFeatures); err != nil { + return err + } + + tx, err := datastore.Begin() + if err != nil { + return err + } + + err = tx.UpsertAncestry(database.Ancestry{Name: name, Layers: ancestryLayers}, ancestryFeatures, Processors) + if err != nil { + tx.Rollback() + return err + } + + err = tx.Commit() + if err != nil { + return err + } + return nil +} + +func persistNamespacedFeatures(datastore database.Datastore, features []database.NamespacedFeature) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + + if err := tx.PersistNamespacedFeatures(features); err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + tx, err = datastore.Begin() + if err != nil { + return err + } + + if err := tx.CacheAffectedNamespacedFeatures(features); err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + +// validateProcessors checks if the layers processed by same set of processors. +func validateProcessors(layers []database.LayerWithContent) bool { + if len(layers) == 0 { + return true + } + detectors := layers[0].ProcessedBy.Detectors + listers := layers[0].ProcessedBy.Listers + + for _, l := range layers[1:] { + if len(strutil.CompareStringLists(detectors, l.ProcessedBy.Detectors)) != 0 || + len(strutil.CompareStringLists(listers, l.ProcessedBy.Listers)) != 0 { + return false + } + } + return true +} + +// computeAncestryFeatures computes the features in an ancestry based on all +// layers. +func computeAncestryFeatures(ancestryLayers []database.LayerWithContent) ([]database.NamespacedFeature, error) { + // version format -> namespace + namespaces := map[string]database.Namespace{} + // version format -> feature ID -> feature + features := map[string]map[string]database.NamespacedFeature{} + for _, layer := range ancestryLayers { + // At start of the loop, namespaces and features always contain the + // previous layer's result. + for _, ns := range layer.Namespaces { + namespaces[ns.VersionFormat] = ns + } + + // version format -> feature ID -> feature + currentFeatures := map[string]map[string]database.NamespacedFeature{} + for _, f := range layer.Features { + if ns, ok := namespaces[f.VersionFormat]; ok { + var currentMap map[string]database.NamespacedFeature + if currentMap, ok = currentFeatures[f.VersionFormat]; !ok { + currentFeatures[f.VersionFormat] = make(map[string]database.NamespacedFeature) + currentMap = currentFeatures[f.VersionFormat] + } + + inherited := false + if mapF, ok := features[f.VersionFormat]; ok { + if parentFeature, ok := mapF[f.Name+":"+f.Version]; ok { + currentMap[f.Name+":"+f.Version] = parentFeature + inherited = true + } + } + + if !inherited { + currentMap[f.Name+":"+f.Version] = database.NamespacedFeature{ + Feature: f, + Namespace: ns, + } + } + + } else { + return nil, errors.New("No corresponding version format") + } + } + + // NOTE(Sida): we update the feature map in some version format + // only if there's at least one feature with that version format. This + // approach won't differentiate feature file removed vs all detectable + // features removed from that file vs feature file not changed. + // + // One way to differentiate (feature file removed or not changed) vs + // all detectable features removed is to pass in the file status. + for vf, mapF := range currentFeatures { + features[vf] = mapF + } + } + + ancestryFeatures := []database.NamespacedFeature{} + for _, featureMap := range features { + for _, feature := range featureMap { + ancestryFeatures = append(ancestryFeatures, feature) + } + } + return ancestryFeatures, nil +} + +// getNotProcessedBy returns a processors, which contains the detectors and +// listers not in `processedBy` but implemented in the current clair instance. +func getNotProcessedBy(processedBy database.Processors) database.Processors { + notProcessedLister := strutil.CompareStringLists(Processors.Listers, processedBy.Listers) + notProcessedDetector := strutil.CompareStringLists(Processors.Detectors, processedBy.Detectors) + return database.Processors{ + Listers: notProcessedLister, + Detectors: notProcessedDetector, + } +} + +// detectContent downloads a layer and detects all features and namespaces. +func detectContent(imageFormat, name, path string, headers map[string]string, toProcess database.Processors) (namespaces []database.Namespace, featureVersions []database.Feature, err error) { + log.WithFields(log.Fields{"Hash": name}).Debug("Process Layer") + totalRequiredFiles := append(featurefmt.RequiredFilenames(toProcess.Listers), featurens.RequiredFilenames(toProcess.Detectors)...) files, err := imagefmt.Extract(imageFormat, path, headers, totalRequiredFiles) if err != nil { - log.WithError(err).WithFields(log.Fields{logLayerName: name, "path": cleanURL(path)}).Error("failed to extract data from path") + log.WithError(err).WithFields(log.Fields{ + logLayerName: name, + "path": cleanURL(path), + }).Error("failed to extract data from path") return } - namespaces, err = detectNamespaces(name, files, parent) - if err != nil { - return - } - - featureVersions, err = detectFeatureVersions(name, files, namespaces, parent) + namespaces, err = featurens.Detect(files, toProcess.Detectors) if err != nil { return } if len(featureVersions) > 0 { - log.WithFields(log.Fields{logLayerName: name, "feature count": len(featureVersions)}).Debug("detected features") + log.WithFields(log.Fields{logLayerName: name, "count": len(namespaces)}).Debug("detected layer namespaces") } - return -} - -// detectNamespaces returns a list of unique namespaces detected in a layer and its ancestry. -func detectNamespaces(name string, files tarutil.FilesMap, parent *database.Layer) (namespaces []database.Namespace, err error) { - nsSet := map[string]*database.Namespace{} - nsCurrent, err := featurens.Detect(files) + featureVersions, err = featurefmt.ListFeatures(files, toProcess.Listers) if err != nil { return } - if parent != nil { - for _, ns := range parent.Namespaces { - // Under assumption that one version format corresponds to one type - // of namespace. - nsSet[ns.VersionFormat] = &ns - log.WithFields(log.Fields{logLayerName: name, "detected namespace": ns.Name, "version format": ns.VersionFormat}).Debug("detected namespace (from parent)") - } - } - - for _, ns := range nsCurrent { - nsSet[ns.VersionFormat] = &ns - log.WithFields(log.Fields{logLayerName: name, "detected namespace": ns.Name, "version format": ns.VersionFormat}).Debug("detected namespace") - } - - for _, ns := range nsSet { - namespaces = append(namespaces, *ns) - } - return -} - -func detectFeatureVersions(name string, files tarutil.FilesMap, namespaces []database.Namespace, parent *database.Layer) (features []database.FeatureVersion, err error) { - // Build a map of the namespaces for each FeatureVersion in our parent layer. - parentFeatureNamespaces := make(map[string]database.Namespace) - if parent != nil { - for _, parentFeature := range parent.Features { - parentFeatureNamespaces[parentFeature.Feature.Name+":"+parentFeature.Version] = parentFeature.Feature.Namespace - } - } - - for _, ns := range namespaces { - // TODO(Quentin-M): We need to pass the parent image to DetectFeatures because it's possible that - // some detectors would need it in order to produce the entire feature list (if they can only - // detect a diff). Also, we should probably pass the detected namespace so detectors could - // make their own decision. - detectedFeatures, err := featurefmt.ListFeatures(files, &ns) - if err != nil { - return features, err - } - - // Ensure that each FeatureVersion has an associated Namespace. - for i, feature := range detectedFeatures { - if feature.Feature.Namespace.Name != "" { - // There is a Namespace associated. - continue - } - - if parentFeatureNamespace, ok := parentFeatureNamespaces[feature.Feature.Name+":"+feature.Version]; ok { - // The FeatureVersion is present in the parent layer; associate - // with their Namespace. - // This might cause problem because a package with same feature - // name and version could be different in parent layer's - // namespace and current layer's namespace - detectedFeatures[i].Feature.Namespace = parentFeatureNamespace - continue - } - - detectedFeatures[i].Feature.Namespace = ns - } - features = append(features, detectedFeatures...) - } - - // If there are no FeatureVersions, use parent's FeatureVersions if possible. - // TODO(Quentin-M): We eventually want to give the choice to each detectors to use none/some of - // their parent's FeatureVersions. It would be useful for detectors that can't find their entire - // result using one Layer. - if len(features) == 0 && parent != nil { - features = parent.Features + if len(featureVersions) > 0 { + log.WithFields(log.Fields{logLayerName: name, "count": len(featureVersions)}).Debug("detected layer features") } return diff --git a/worker_test.go b/worker_test.go index 950c7689..5f6e0ff4 100644 --- a/worker_test.go +++ b/worker_test.go @@ -15,18 +15,23 @@ package clair import ( + "errors" "path/filepath" "runtime" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/strutil" // Register the required detectors. _ "github.com/coreos/clair/ext/featurefmt/dpkg" + _ "github.com/coreos/clair/ext/featurefmt/rpm" _ "github.com/coreos/clair/ext/featurens/aptsources" _ "github.com/coreos/clair/ext/featurens/osrelease" _ "github.com/coreos/clair/ext/imagefmt/docker" @@ -34,42 +39,306 @@ import ( type mockDatastore struct { database.MockDatastore - layers map[string]database.Layer + + layers map[string]database.LayerWithContent + ancestry map[string]database.AncestryWithFeatures + namespaces map[string]database.Namespace + features map[string]database.Feature + namespacedFeatures map[string]database.NamespacedFeature +} + +type mockSession struct { + database.MockSession + + store *mockDatastore + copy mockDatastore + terminated bool +} + +func copyDatastore(md *mockDatastore) mockDatastore { + layers := map[string]database.LayerWithContent{} + for k, l := range md.layers { + features := append([]database.Feature(nil), l.Features...) + namespaces := append([]database.Namespace(nil), l.Namespaces...) + listers := append([]string(nil), l.ProcessedBy.Listers...) + detectors := append([]string(nil), l.ProcessedBy.Detectors...) + layers[k] = database.LayerWithContent{ + Layer: database.Layer{ + Hash: l.Hash, + }, + ProcessedBy: database.Processors{ + Listers: listers, + Detectors: detectors, + }, + Features: features, + Namespaces: namespaces, + } + } + + ancestry := map[string]database.AncestryWithFeatures{} + for k, a := range md.ancestry { + nf := append([]database.NamespacedFeature(nil), a.Features...) + l := append([]database.Layer(nil), a.Layers...) + listers := append([]string(nil), a.ProcessedBy.Listers...) + detectors := append([]string(nil), a.ProcessedBy.Detectors...) + ancestry[k] = database.AncestryWithFeatures{ + Ancestry: database.Ancestry{ + Name: a.Name, + Layers: l, + }, + ProcessedBy: database.Processors{ + Detectors: detectors, + Listers: listers, + }, + Features: nf, + } + } + + namespaces := map[string]database.Namespace{} + for k, n := range md.namespaces { + namespaces[k] = n + } + + features := map[string]database.Feature{} + for k, f := range md.features { + features[k] = f + } + + namespacedFeatures := map[string]database.NamespacedFeature{} + for k, f := range md.namespacedFeatures { + namespacedFeatures[k] = f + } + return mockDatastore{ + layers: layers, + ancestry: ancestry, + namespaces: namespaces, + namespacedFeatures: namespacedFeatures, + features: features, + } } func newMockDatastore() *mockDatastore { - return &mockDatastore{ - layers: make(map[string]database.Layer), + errSessionDone := errors.New("Session Done") + md := &mockDatastore{ + layers: make(map[string]database.LayerWithContent), + ancestry: make(map[string]database.AncestryWithFeatures), + namespaces: make(map[string]database.Namespace), + features: make(map[string]database.Feature), + namespacedFeatures: make(map[string]database.NamespacedFeature), } + + md.FctBegin = func() (database.Session, error) { + session := &mockSession{ + store: md, + copy: copyDatastore(md), + terminated: false, + } + + session.FctCommit = func() error { + if session.terminated { + return nil + } + session.store.layers = session.copy.layers + session.store.ancestry = session.copy.ancestry + session.store.namespaces = session.copy.namespaces + session.store.features = session.copy.features + session.store.namespacedFeatures = session.copy.namespacedFeatures + session.terminated = true + return nil + } + + session.FctRollback = func() error { + if session.terminated { + return nil + } + session.terminated = true + session.copy = mockDatastore{} + return nil + } + + session.FctFindAncestry = func(name string) (database.Ancestry, database.Processors, bool, error) { + processors := database.Processors{} + if session.terminated { + return database.Ancestry{}, processors, false, errSessionDone + } + ancestry, ok := session.copy.ancestry[name] + return ancestry.Ancestry, ancestry.ProcessedBy, ok, nil + } + + session.FctFindLayer = func(name string) (database.Layer, database.Processors, bool, error) { + processors := database.Processors{} + if session.terminated { + return database.Layer{}, processors, false, errSessionDone + } + layer, ok := session.copy.layers[name] + return layer.Layer, layer.ProcessedBy, ok, nil + } + + session.FctFindLayerWithContent = func(name string) (database.LayerWithContent, bool, error) { + if session.terminated { + return database.LayerWithContent{}, false, errSessionDone + } + layer, ok := session.copy.layers[name] + return layer, ok, nil + } + + session.FctPersistLayer = func(layer database.Layer) error { + if session.terminated { + return errSessionDone + } + if _, ok := session.copy.layers[layer.Hash]; !ok { + session.copy.layers[layer.Hash] = database.LayerWithContent{Layer: layer} + } + return nil + } + + session.FctPersistNamespaces = func(ns []database.Namespace) error { + if session.terminated { + return errSessionDone + } + for _, n := range ns { + _, ok := session.copy.namespaces[n.Name] + if !ok { + session.copy.namespaces[n.Name] = n + } + } + return nil + } + + session.FctPersistFeatures = func(fs []database.Feature) error { + if session.terminated { + return errSessionDone + } + for _, f := range fs { + key := FeatureKey(&f) + _, ok := session.copy.features[key] + if !ok { + session.copy.features[key] = f + } + } + return nil + } + + session.FctPersistLayerContent = func(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { + if session.terminated { + return errSessionDone + } + + // update the layer + layer, ok := session.copy.layers[hash] + if !ok { + return errors.New("layer not found") + } + + layerFeatures := map[string]database.Feature{} + layerNamespaces := map[string]database.Namespace{} + for _, f := range layer.Features { + layerFeatures[FeatureKey(&f)] = f + } + for _, n := range layer.Namespaces { + layerNamespaces[n.Name] = n + } + + // ensure that all the namespaces, features are in the database + for _, ns := range namespaces { + if _, ok := session.copy.namespaces[ns.Name]; !ok { + return errors.New("Namespaces should be in the database") + } + if _, ok := layerNamespaces[ns.Name]; !ok { + layer.Namespaces = append(layer.Namespaces, ns) + layerNamespaces[ns.Name] = ns + } + } + + for _, f := range features { + if _, ok := session.copy.features[FeatureKey(&f)]; !ok { + return errors.New("Namespaces should be in the database") + } + if _, ok := layerFeatures[FeatureKey(&f)]; !ok { + layer.Features = append(layer.Features, f) + layerFeatures[FeatureKey(&f)] = f + } + } + + layer.ProcessedBy.Detectors = append(layer.ProcessedBy.Detectors, strutil.CompareStringLists(processedBy.Detectors, layer.ProcessedBy.Detectors)...) + layer.ProcessedBy.Listers = append(layer.ProcessedBy.Listers, strutil.CompareStringLists(processedBy.Listers, layer.ProcessedBy.Listers)...) + + session.copy.layers[hash] = layer + return nil + } + + session.FctUpsertAncestry = func(ancestry database.Ancestry, features []database.NamespacedFeature, processors database.Processors) error { + if session.terminated { + return errSessionDone + } + + // ensure features are in the database + for _, f := range features { + if _, ok := session.copy.namespacedFeatures[NamespacedFeatureKey(&f)]; !ok { + return errors.New("namepsaced feature not in db") + } + } + + ancestryWFeature := database.AncestryWithFeatures{ + Ancestry: ancestry, + Features: features, + ProcessedBy: processors, + } + + session.copy.ancestry[ancestry.Name] = ancestryWFeature + return nil + } + + session.FctPersistNamespacedFeatures = func(namespacedFeatures []database.NamespacedFeature) error { + for i, f := range namespacedFeatures { + session.copy.namespacedFeatures[NamespacedFeatureKey(&f)] = namespacedFeatures[i] + } + return nil + } + + session.FctCacheAffectedNamespacedFeatures = func(namespacedFeatures []database.NamespacedFeature) error { + // The function does nothing because we don't care about the vulnerability cache in worker_test. + return nil + } + + return session, nil + } + return md } -func TestProcessWithDistUpgrade(t *testing.T) { - _, f, _, _ := runtime.Caller(0) - testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" - - // Create a mock datastore. - datastore := newMockDatastore() - datastore.FctInsertLayer = func(layer database.Layer) error { - datastore.layers[layer.Name] = layer - return nil +func TestMain(m *testing.M) { + Processors = database.Processors{ + Listers: featurefmt.ListListers(), + Detectors: featurens.ListDetectors(), } - datastore.FctFindLayer = func(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { - if layer, exists := datastore.layers[name]; exists { - return layer, nil - } - return database.Layer{}, commonerr.ErrNotFound + m.Run() +} + +func FeatureKey(f *database.Feature) string { + return strings.Join([]string{f.Name, f.VersionFormat, f.Version}, "__") +} + +func NamespacedFeatureKey(f *database.NamespacedFeature) string { + return strings.Join([]string{f.Name, f.Namespace.Name}, "__") +} + +func TestProcessAncestryWithDistUpgrade(t *testing.T) { + // Create the list of Features that should not been upgraded from one layer to another. + nonUpgradedFeatures := []database.Feature{ + {Name: "libtext-wrapi18n-perl", Version: "0.06-7"}, + {Name: "libtext-charwidth-perl", Version: "0.04-7"}, + {Name: "libtext-iconv-perl", Version: "1.7-5"}, + {Name: "mawk", Version: "1.3.3-17"}, + {Name: "insserv", Version: "1.14.0-5"}, + {Name: "db", Version: "5.1.29-5"}, + {Name: "ustr", Version: "1.0.4-3"}, + {Name: "xz-utils", Version: "5.1.1alpha+20120614-2"}, } - // Create the list of FeatureVersions that should not been upgraded from one layer to another. - nonUpgradedFeatureVersions := []database.FeatureVersion{ - {Feature: database.Feature{Name: "libtext-wrapi18n-perl"}, Version: "0.06-7"}, - {Feature: database.Feature{Name: "libtext-charwidth-perl"}, Version: "0.04-7"}, - {Feature: database.Feature{Name: "libtext-iconv-perl"}, Version: "1.7-5"}, - {Feature: database.Feature{Name: "mawk"}, Version: "1.3.3-17"}, - {Feature: database.Feature{Name: "insserv"}, Version: "1.14.0-5"}, - {Feature: database.Feature{Name: "db"}, Version: "5.1.29-5"}, - {Feature: database.Feature{Name: "ustr"}, Version: "1.0.4-3"}, - {Feature: database.Feature{Name: "xz-utils"}, Version: "5.1.1alpha+20120614-2"}, + nonUpgradedMap := map[database.Feature]struct{}{} + for _, f := range nonUpgradedFeatures { + f.VersionFormat = "dpkg" + nonUpgradedMap[f] = struct{}{} } // Process test layers. @@ -78,42 +347,294 @@ func TestProcessWithDistUpgrade(t *testing.T) { // wheezy.tar: FROM debian:wheezy // jessie.tar: RUN sed -i "s/precise/trusty/" /etc/apt/sources.list && apt-get update && // apt-get -y dist-upgrade - assert.Nil(t, ProcessLayer(datastore, "Docker", "blank", "", testDataPath+"blank.tar.gz", nil)) - assert.Nil(t, ProcessLayer(datastore, "Docker", "wheezy", "blank", testDataPath+"wheezy.tar.gz", nil)) - assert.Nil(t, ProcessLayer(datastore, "Docker", "jessie", "wheezy", testDataPath+"jessie.tar.gz", nil)) + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" - // Ensure that the 'wheezy' layer has the expected namespace and features. - wheezy, ok := datastore.layers["wheezy"] - if assert.True(t, ok, "layer 'wheezy' not processed") { - if !assert.Len(t, wheezy.Namespaces, 1) { - return - } - assert.Equal(t, "debian:7", wheezy.Namespaces[0].Name) - assert.Len(t, wheezy.Features, 52) + datastore := newMockDatastore() - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:7" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.Contains(t, wheezy.Features, nufv) + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + // check the ancestry features + assert.Len(t, datastore.ancestry["Mock"].Features, 74) + for _, f := range datastore.ancestry["Mock"].Features { + if _, ok := nonUpgradedMap[f.Feature]; ok { + assert.Equal(t, "debian:7", f.Namespace.Name) + } else { + assert.Equal(t, "debian:8", f.Namespace.Name) } } - // Ensure that the 'wheezy' layer has the expected namespace and non-upgraded features. - jessie, ok := datastore.layers["jessie"] - if assert.True(t, ok, "layer 'jessie' not processed") { - assert.Len(t, jessie.Namespaces, 1) - assert.Equal(t, "debian:8", jessie.Namespaces[0].Name) - assert.Len(t, jessie.Features, 74) + assert.Equal(t, []database.Layer{ + {Hash: "blank"}, + {Hash: "wheezy"}, + {Hash: "jessie"}, + }, datastore.ancestry["Mock"].Layers) +} - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:7" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.Contains(t, jessie.Features, nufv) +func TestProcessLayers(t *testing.T) { + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" + + datastore := newMockDatastore() + + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + processedLayers, err := processLayers(datastore, "Docker", layers) + assert.Nil(t, err) + assert.Len(t, processedLayers, 3) + // ensure resubmit won't break the stuff + processedLayers, err = processLayers(datastore, "Docker", layers) + assert.Nil(t, err) + assert.Len(t, processedLayers, 3) + // Ensure each processed layer is correct + assert.Len(t, processedLayers[0].Namespaces, 0) + assert.Len(t, processedLayers[1].Namespaces, 1) + assert.Len(t, processedLayers[2].Namespaces, 1) + assert.Len(t, processedLayers[0].Features, 0) + assert.Len(t, processedLayers[1].Features, 52) + assert.Len(t, processedLayers[2].Features, 74) + + // Ensure each layer has expected namespaces and features detected + if blank, ok := datastore.layers["blank"]; ok { + assert.Equal(t, blank.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, blank.ProcessedBy.Listers, Processors.Listers) + assert.Len(t, blank.Namespaces, 0) + assert.Len(t, blank.Features, 0) + } else { + assert.Fail(t, "blank is not stored") + return + } + + if wheezy, ok := datastore.layers["wheezy"]; ok { + assert.Equal(t, wheezy.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, wheezy.ProcessedBy.Listers, Processors.Listers) + assert.Equal(t, wheezy.Namespaces, []database.Namespace{{Name: "debian:7", VersionFormat: dpkg.ParserName}}) + assert.Len(t, wheezy.Features, 52) + } else { + assert.Fail(t, "wheezy is not stored") + return + } + + if jessie, ok := datastore.layers["jessie"]; ok { + assert.Equal(t, jessie.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, jessie.ProcessedBy.Listers, Processors.Listers) + assert.Equal(t, jessie.Namespaces, []database.Namespace{{Name: "debian:8", VersionFormat: dpkg.ParserName}}) + assert.Len(t, jessie.Features, 74) + } else { + assert.Fail(t, "jessie is not stored") + return + } +} + +// TestUpgradeClair checks if a clair is upgraded and certain ancestry's +// features should not change. We assume that Clair should only upgrade +func TestClairUpgrade(t *testing.T) { + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" + + datastore := newMockDatastore() + + // suppose there are two ancestries. + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + layers2 := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + } + + // Suppose user scan an ancestry with an old instance of Clair. + Processors = database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"rpm"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 0) + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock2", layers2)) + assert.Len(t, datastore.ancestry["Mock2"].Features, 0) + + // Clair is upgraded to use a new namespace detector. The expected + // behavior is that all layers will be rescanned with "apt-sources" and + // the ancestry's features are recalculated. + Processors = database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"rpm"}, + } + + // Even though Clair processors are upgraded, the ancestry's features should + // not be upgraded without posting the ancestry to Clair again. + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 0) + + // Clair is upgraded to use a new feature lister. The expected behavior is + // that all layers will be rescanned with "dpkg" and the ancestry's features + // are invalidated and recalculated. + Processors = database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"rpm", "dpkg"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 74) + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock2", layers2)) + assert.Len(t, datastore.ancestry["Mock2"].Features, 52) + + // check the namespaces are correct + for _, f := range datastore.ancestry["Mock"].Features { + if !assert.NotEqual(t, database.Namespace{}, f.Namespace) { + assert.Fail(t, "Every feature should have a namespace attached") } - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:8" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.NotContains(t, jessie.Features, nufv) + } + + for _, f := range datastore.ancestry["Mock2"].Features { + if !assert.NotEqual(t, database.Namespace{}, f.Namespace) { + assert.Fail(t, "Every feature should have a namespace attached") } } } + +// TestMultipleNamespaces tests computing ancestry features +func TestComputeAncestryFeatures(t *testing.T) { + vf1 := "format 1" + vf2 := "format 2" + + ns1a := database.Namespace{ + Name: "namespace 1:a", + VersionFormat: vf1, + } + + ns1b := database.Namespace{ + Name: "namespace 1:b", + VersionFormat: vf1, + } + + ns2a := database.Namespace{ + Name: "namespace 2:a", + VersionFormat: vf2, + } + + ns2b := database.Namespace{ + Name: "namespace 2:b", + VersionFormat: vf2, + } + + f1 := database.Feature{ + Name: "feature 1", + Version: "0.1", + VersionFormat: vf1, + } + + f2 := database.Feature{ + Name: "feature 2", + Version: "0.2", + VersionFormat: vf1, + } + + f3 := database.Feature{ + Name: "feature 1", + Version: "0.3", + VersionFormat: vf2, + } + + f4 := database.Feature{ + Name: "feature 2", + Version: "0.3", + VersionFormat: vf2, + } + + // Suppose Clair is watching two files for namespaces one containing ns1 + // changes e.g. os-release and the other one containing ns2 changes e.g. + // node. + blank := database.LayerWithContent{Layer: database.Layer{Hash: "blank"}} + initNS1a := database.LayerWithContent{ + Layer: database.Layer{Hash: "init ns1a"}, + Namespaces: []database.Namespace{ns1a}, + Features: []database.Feature{f1, f2}, + } + + upgradeNS2b := database.LayerWithContent{ + Layer: database.Layer{Hash: "upgrade ns2b"}, + Namespaces: []database.Namespace{ns2b}, + } + + upgradeNS1b := database.LayerWithContent{ + Layer: database.Layer{Hash: "upgrade ns1b"}, + Namespaces: []database.Namespace{ns1b}, + Features: []database.Feature{f1, f2}, + } + + initNS2a := database.LayerWithContent{ + Layer: database.Layer{Hash: "init ns2a"}, + Namespaces: []database.Namespace{ns2a}, + Features: []database.Feature{f3, f4}, + } + + removeF2 := database.LayerWithContent{ + Layer: database.Layer{Hash: "remove f2"}, + Features: []database.Feature{f1}, + } + + // blank -> ns1:a, f1 f2 (init) + // -> f1 (feature change) + // -> ns2:a, f3, f4 (init ns2a) + // -> ns2:b (ns2 upgrade without changing features) + // -> blank (empty) + // -> ns1:b, f1 f2 (ns1 upgrade and add f2) + // -> f1 (remove f2) + // -> blank (empty) + + layers := []database.LayerWithContent{ + blank, + initNS1a, + removeF2, + initNS2a, + upgradeNS2b, + blank, + upgradeNS1b, + removeF2, + blank, + } + + expected := map[database.NamespacedFeature]bool{ + { + Feature: f1, + Namespace: ns1a, + }: false, + { + Feature: f3, + Namespace: ns2a, + }: false, + { + Feature: f4, + Namespace: ns2a, + }: false, + } + + features, err := computeAncestryFeatures(layers) + assert.Nil(t, err) + for _, f := range features { + if assert.Contains(t, expected, f) { + if assert.False(t, expected[f]) { + expected[f] = true + } + } + } + + for f, visited := range expected { + assert.True(t, visited, "expected feature is missing : "+f.Namespace.Name+":"+f.Name) + } +}