From 20a126c84ae8fad5f9a9ee3bc042866407d48308 Mon Sep 17 00:00:00 2001 From: Quentin Machu Date: Mon, 23 Nov 2015 23:43:33 -0500 Subject: [PATCH 1/2] notifier: Refactor and add client certificate authentification support. Fixes #23 --- main.go | 29 +++-- notifier/notifier.go | 252 +++++++++++++++++++++++-------------------- utils/http/http.go | 55 ++++++++++ 3 files changed, 205 insertions(+), 131 deletions(-) create mode 100644 utils/http/http.go diff --git a/main.go b/main.go index cfaccbd3..c6ace87d 100644 --- a/main.go +++ b/main.go @@ -45,8 +45,10 @@ var ( cfgDbPath = kingpin.Flag("db-path", "Path to the database to use").String() // Notifier configuration - cfgNotifierType = kingpin.Flag("notifier-type", "Type of the notifier to use").Default("none").Enum("none", "http") - cfgNotifierHTTPURL = kingpin.Flag("notifier-http-url", "URL that will receive POST notifications").String() + cfgNotifierEndpoint = kingpin.Flag("notifier-endpoint", "URL that will receive POST notifications").String() + cfgNotifierCertFile = kingpin.Flag("notifier-cert-file", "Path to TLS Cert file").ExistingFile() + cfgNotifierKeyFile = kingpin.Flag("notifier-key-file", "Path to TLS Key file").ExistingFile() + cfgNotifierCAFile = kingpin.Flag("notifier-ca-file", "Path to CA for verifying TLS client certs").ExistingFile() // Updater configuration cfgUpdateInterval = kingpin.Flag("update-interval", "Frequency at which the vulnerability updater will run. Use 0 to disable the updater entirely.").Default("1h").Duration() @@ -75,10 +77,6 @@ func main() { kingpin.Errorf("required flag --db-path not provided, try --help") os.Exit(1) } - if *cfgNotifierType == "http" && *cfgNotifierHTTPURL == "" { - kingpin.Errorf("required flag --notifier-http-url not provided, try --help") - os.Exit(1) - } // Initialize error/logging system logLevel, err := capnslog.ParseLevel(strings.ToUpper(*cfgLogLevel)) @@ -110,17 +108,16 @@ func main() { defer database.Close() // Start notifier - var notifierService notifier.Notifier - switch *cfgNotifierType { - case "http": - notifierService, err = notifier.NewHTTPNotifier(*cfgNotifierHTTPURL) - if err != nil { - log.Fatalf("could not initialize HTTP notifier: %s", err) - } - } - if notifierService != nil { + if len(*cfgNotifierEndpoint) > 0 { + notifier := notifier.New(notifier.Config{ + Endpoint: *cfgNotifierEndpoint, + CertFile: *cfgNotifierCertFile, + KeyFile: *cfgNotifierKeyFile, + CAFile: *cfgNotifierCAFile, + }) + st.Begin() - go notifierService.Run(st) + go notifier.Serve(st) } // Start Main API and Health API diff --git a/notifier/notifier.go b/notifier/notifier.go index 6793098d..dc4995ea 100644 --- a/notifier/notifier.go +++ b/notifier/notifier.go @@ -24,150 +24,172 @@ import ( "time" "github.com/coreos/pkg/capnslog" - "github.com/coreos/pkg/timeutil" + "github.com/pborman/uuid" + "github.com/coreos/clair/database" - cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/clair/health" "github.com/coreos/clair/utils" - "github.com/pborman/uuid" + httputils "github.com/coreos/clair/utils/http" ) -// A Notifier dispatches notifications -type Notifier interface { - Run(*utils.Stopper) -} - var log = capnslog.NewPackageLogger("github.com/coreos/clair", "notifier") const ( - maxBackOff = 5 * time.Minute - checkInterval = 5 * time.Second + checkInterval = 5 * time.Minute - refreshLockAnticipation = time.Minute * 2 - lockDuration = time.Minute*8 + refreshLockAnticipation + refreshLockDuration = time.Minute * 2 + lockDuration = time.Minute*8 + refreshLockDuration ) -// A HTTPNotifier dispatches notifications to an HTTP endpoint with POST requests -type HTTPNotifier struct { - url string +// A Notification represents the structure of the notifications that are sent by a Notifier. +type Notification struct { + Name, Type string + Content interface{} } -// NewHTTPNotifier initializes a new HTTPNotifier -func NewHTTPNotifier(URL string) (*HTTPNotifier, error) { - if _, err := url.Parse(URL); err != nil { - return nil, cerrors.NewBadRequestError("could not create a notifier with an invalid URL") +// A Notifier dispatches notifications to an HTTP endpoint. +type Notifier struct { + lockIdentifier string + endpoint string + client *http.Client +} + +// Config represents the configuration of a Notifier. +// The certificates are optionnals and enables client certificate authentification. +type Config struct { + Endpoint string + CertFile, KeyFile, CAFile string +} + +// New initializes a new Notifier from the specified configuration. +func New(cfg Config) *Notifier { + if _, err := url.Parse(cfg.Endpoint); err != nil { + log.Fatal("could not create a notifier with an invalid endpoint URL") } - notifier := &HTTPNotifier{url: URL} - health.RegisterHealthchecker("notifier", notifier.Healthcheck) + // Initialize TLS + tlsConfig, err := httputils.LoadTLSClientConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile) + if err != nil { + log.Fatalf("could not initialize client cert authentification: %s\n", err) + } + if tlsConfig != nil { + log.Info("notifier configured with client certificate authentification") + } - return notifier, nil + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + return &Notifier{ + lockIdentifier: uuid.New(), + endpoint: cfg.Endpoint, + client: httpClient, + } } -// Run pops notifications from the database, lock them, send them, mark them as -// send and unlock them -// -// It uses an exponential backoff when POST requests fail -func (notifier *HTTPNotifier) Run(st *utils.Stopper) { - defer st.End() +// Serve starts the Notifier. +func (n *Notifier) Serve(stopper *utils.Stopper) { + defer stopper.End() + health.RegisterHealthchecker("notifier", n.Healthcheck) - whoAmI := uuid.New() - log.Infof("HTTP notifier started. URL: %s. Lock Identifier: %s", notifier.url, whoAmI) + log.Infof("notifier service started. endpoint: %s. lock identifier: %s\n", n.endpoint, n.lockIdentifier) for { - node, notification, err := database.FindOneNotificationToSend(database.GetDefaultNotificationWrapper()) - if notification == nil || err != nil { - if err != nil { - log.Warningf("could not get notification to send: %s.", err) - } - - if !st.Sleep(checkInterval) { - break - } - - continue - } - - // Try to lock the notification - hasLock, hasLockUntil := database.Lock(node, lockDuration, whoAmI) - if !hasLock { - continue - } - - for backOff := time.Duration(0); ; backOff = timeutil.ExpBackoff(backOff, maxBackOff) { - // Backoff, it happens when an error occurs during the communication - // with the notification endpoint - if backOff > 0 { - // Renew lock before going to sleep if necessary - if time.Now().Add(backOff).After(hasLockUntil.Add(-refreshLockAnticipation)) { - hasLock, hasLockUntil = database.Lock(node, lockDuration, whoAmI) - if !hasLock { - log.Warning("lost lock ownership, aborting") - break - } - } - - // Sleep - if !st.Sleep(backOff) { - return - } - } - - // Get notification content - content, err := notification.GetContent() - if err != nil { - log.Warningf("could not get content of notification '%s': %s", notification.GetName(), err.Error()) - break - } - - // Marshal the notification content - jsonContent, err := json.Marshal(struct { - Name, Type string - Content interface{} - }{ - Name: notification.GetName(), - Type: notification.GetType(), - Content: content, - }) - if err != nil { - log.Errorf("could not marshal content of notification '%s': %s", notification.GetName(), err.Error()) - break - } - - // Post notification - req, _ := http.NewRequest("POST", notifier.url, bytes.NewBuffer(jsonContent)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - res, err := client.Do(req) - if err != nil { - log.Warningf("could not post notification '%s': %s", notification.GetName(), err.Error()) - continue - } - res.Body.Close() - - if res.StatusCode != 200 && res.StatusCode != 201 { - log.Warningf("could not post notification '%s': got status code %d", notification.GetName(), res.StatusCode) - continue - } - - // Mark the notification as sent - database.MarkNotificationAsSent(node) - - log.Infof("sent notification '%s' successfully", notification.GetName()) + // Find task. + // TODO(Quentin-M): Combine node and notification. + node, notification := n.findTask(stopper) + if node == "" && notification == nil { break } - if hasLock { - database.Unlock(node, whoAmI) + // Handle task. + done := make(chan bool, 1) + go func() { + if n.handleTask(node, notification) { + database.MarkNotificationAsSent(node) + } + database.Unlock(node, n.lockIdentifier) + done <- true + }() + + // Refresh task lock until done. + outer: + for { + select { + case <-done: + break outer + case <-time.After(refreshLockDuration): + database.Lock(node, lockDuration, n.lockIdentifier) + } } } - log.Info("HTTP notifier stopped") + log.Info("notifier service stopped") } -// Healthcheck returns the health of the notifier service -func (notifier *HTTPNotifier) Healthcheck() health.Status { +func (n *Notifier) findTask(stopper *utils.Stopper) (string, database.Notification) { + for { + // Find a notification to send. + node, notification, err := database.FindOneNotificationToSend(database.GetDefaultNotificationWrapper()) + if err != nil { + log.Warningf("could not get notification to send: %s", err) + } + + // No notification or error: wait. + if notification == nil || err != nil { + if !stopper.Sleep(checkInterval) { + return "", nil + } + continue + } + + // Lock the notification. + if hasLock, _ := database.Lock(node, lockDuration, n.lockIdentifier); hasLock { + log.Infof("found and locked a notification: %s", notification.GetName()) + return node, notification + } + } +} + +func (n *Notifier) handleTask(node string, notification database.Notification) bool { + // Get notification content. + // TODO(Quentin-M): Split big notifications. + notificationContent, err := notification.GetContent() + if err != nil { + log.Warningf("could not get content of notification '%s': %s", notification.GetName(), err) + return false + } + + // Create notification. + payload := Notification{ + Name: notification.GetName(), + Type: notification.GetType(), + Content: notificationContent, + } + + // Marshal notification. + jsonPayload, err := json.Marshal(payload) + if err != nil { + log.Errorf("could not marshal content of notification '%s': %s", notification.GetName(), err) + return false + } + + // Send notification. + resp, err := n.client.Post(n.endpoint, "application/json", bytes.NewBuffer(jsonPayload)) + defer resp.Body.Close() + if err != nil || (resp.StatusCode != 200 && resp.StatusCode != 201) { + log.Errorf("could not send notification '%s': (%d) %s", notification.GetName(), resp.StatusCode, err) + return false + } + + log.Infof("successfully sent notification '%s'\n", notification.GetName()) + return true +} + +// Healthcheck returns the health of the notifier service. +func (n *Notifier) Healthcheck() health.Status { queueSize, err := database.CountNotificationsToSend() return health.Status{IsEssential: false, IsHealthy: err == nil, Details: struct{ QueueSize int }{QueueSize: queueSize}} } diff --git a/utils/http/http.go b/utils/http/http.go new file mode 100644 index 00000000..6250d733 --- /dev/null +++ b/utils/http/http.go @@ -0,0 +1,55 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package http provides utility functions for HTTP servers and clients. +package http + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" +) + +// LoadTLSClientConfig initializes a *tls.Config using the given certificates and private key, that +// can be used to communicate with a server using client certificate authentificate. +// +// If no certificates are given, a nil *tls.Config is returned. +// The CA certificate is optionnal, the system defaults are used if not provided. +func LoadTLSClientConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + if len(certFile) == 0 || len(keyFile) == 0 { + return nil, nil + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + var caCertPool *x509.CertPool + if len(caFile) > 0 { + caCert, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } + + return tlsConfig, nil +} From 9946382223431179b1133786bc6debfa1e288fee Mon Sep 17 00:00:00 2001 From: Quentin Machu Date: Tue, 24 Nov 2015 00:18:11 -0500 Subject: [PATCH 2/2] api: Extracted client cert & HTTP JSON Render to utils. --- api/api.go | 40 ++++++------------ api/jsonhttp/json.go | 78 ---------------------------------- api/logic/general.go | 11 ++--- api/logic/layers.go | 81 +++++++++++++++++------------------ api/logic/vulnerabilities.go | 67 ++++++++++++++--------------- api/wrappers/timeout.go | 7 ++- notifier/notifier.go | 2 +- utils/http/http.go | 82 ++++++++++++++++++++++++++++++++++++ 8 files changed, 179 insertions(+), 189 deletions(-) delete mode 100644 api/jsonhttp/json.go diff --git a/api/api.go b/api/api.go index deb18d67..c69825d1 100644 --- a/api/api.go +++ b/api/api.go @@ -17,18 +17,16 @@ package api import ( - "io/ioutil" "net" "net/http" "strconv" "time" - "crypto/tls" - "crypto/x509" - "github.com/coreos/pkg/capnslog" - "github.com/coreos/clair/utils" "github.com/tylerb/graceful" + + "github.com/coreos/clair/utils" + httputils "github.com/coreos/clair/utils/http" ) var log = capnslog.NewPackageLogger("github.com/coreos/clair", "api") @@ -49,12 +47,20 @@ func RunMain(conf *Config, st *utils.Stopper) { st.End() }() + tlsConfig, err := httputils.LoadTLSClientConfigForServer(conf.CAFile) + if err != nil { + log.Fatalf("could not initialize client cert authentification: %s\n", err) + } + if tlsConfig != nil { + log.Info("api configured with client certificate authentification") + } + srv := &graceful.Server{ Timeout: 0, // Already handled by our TimeOut middleware NoSignalHandling: true, // We want to use our own Stopper Server: &http.Server{ Addr: ":" + strconv.Itoa(conf.Port), - TLSConfig: setupClientCert(conf.CAFile), + TLSConfig: tlsConfig, Handler: NewVersionRouter(conf.TimeOut), }, } @@ -102,25 +108,3 @@ func listenAndServeWithStopper(srv *graceful.Server, st *utils.Stopper, certFile log.Fatal(err) } } - -// setupClientCert creates a tls.Config instance using a CA file path -// (if provided) and and calls log.Fatal if it does not exist. -func setupClientCert(caFile string) *tls.Config { - if len(caFile) > 0 { - log.Info("API: Client Certificate Authentification Enabled") - caCert, err := ioutil.ReadFile(caFile) - if err != nil { - log.Fatal(err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - return &tls.Config{ - ClientCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - } - } - - return &tls.Config{ - ClientAuth: tls.NoClientCert, - } -} diff --git a/api/jsonhttp/json.go b/api/jsonhttp/json.go deleted file mode 100644 index b0e39d3f..00000000 --- a/api/jsonhttp/json.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2015 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package jsonhttp provides helper functions to write JSON responses to -// http.ResponseWriter and read JSON bodies from http.Request. -package jsonhttp - -import ( - "encoding/json" - "io" - "net/http" - - "github.com/coreos/clair/database" - cerrors "github.com/coreos/clair/utils/errors" - "github.com/coreos/clair/worker" -) - -// MaxPostSize is the maximum number of bytes that ParseBody reads from an -// http.Request.Body. -var MaxPostSize int64 = 1048576 - -// Render writes a JSON-encoded object to a http.ResponseWriter, as well as -// a HTTP status code. -func Render(w http.ResponseWriter, httpStatus int, v interface{}) { - w.WriteHeader(httpStatus) - if v != nil { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - result, _ := json.Marshal(v) - w.Write(result) - } -} - -// RenderError writes an error, wrapped in the Message field of a JSON-encoded -// object to a http.ResponseWriter, as well as a HTTP status code. -// If the status code is 0, RenderError tries to guess the proper HTTP status -// code from the error type. -func RenderError(w http.ResponseWriter, httpStatus int, err error) { - if httpStatus == 0 { - httpStatus = http.StatusInternalServerError - // Try to guess the http status code from the error type - if _, isBadRequestError := err.(*cerrors.ErrBadRequest); isBadRequestError { - httpStatus = http.StatusBadRequest - } else { - switch err { - case cerrors.ErrNotFound: - httpStatus = http.StatusNotFound - case database.ErrTransaction, database.ErrBackendException: - httpStatus = http.StatusServiceUnavailable - case worker.ErrParentUnknown, worker.ErrUnsupported: - httpStatus = http.StatusBadRequest - } - } - } - - Render(w, httpStatus, struct{ Message string }{Message: err.Error()}) -} - -// ParseBody reads a JSON-encoded body from a http.Request and unmarshals it -// into the provided object. -func ParseBody(r *http.Request, v interface{}) (int, error) { - defer r.Body.Close() - err := json.NewDecoder(io.LimitReader(r.Body, MaxPostSize)).Decode(v) - if err != nil { - return http.StatusUnsupportedMediaType, err - } - return 0, nil -} diff --git a/api/logic/general.go b/api/logic/general.go index 3a040c92..50fa6e7e 100644 --- a/api/logic/general.go +++ b/api/logic/general.go @@ -20,10 +20,11 @@ import ( "net/http" "strconv" - "github.com/coreos/clair/api/jsonhttp" - "github.com/coreos/clair/health" - "github.com/coreos/clair/worker" "github.com/julienschmidt/httprouter" + + "github.com/coreos/clair/health" + httputils "github.com/coreos/clair/utils/http" + "github.com/coreos/clair/worker" ) // Version is an integer representing the API version. @@ -31,7 +32,7 @@ const Version = 1 // GETVersions returns API and Engine versions. func GETVersions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - jsonhttp.Render(w, http.StatusOK, struct { + httputils.WriteHTTP(w, http.StatusOK, struct { APIVersion string EngineVersion string }{ @@ -49,6 +50,6 @@ func GETHealth(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { httpStatus = http.StatusServiceUnavailable } - jsonhttp.Render(w, httpStatus, statuses) + httputils.WriteHTTP(w, httpStatus, statuses) return } diff --git a/api/logic/layers.go b/api/logic/layers.go index 47c1f9bc..5116d3a1 100644 --- a/api/logic/layers.go +++ b/api/logic/layers.go @@ -19,12 +19,13 @@ import ( "net/http" "strconv" - "github.com/coreos/clair/api/jsonhttp" + "github.com/julienschmidt/httprouter" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" + httputils "github.com/coreos/clair/utils/http" "github.com/coreos/clair/utils/types" "github.com/coreos/clair/worker" - "github.com/julienschmidt/httprouter" ) // POSTLayersParameters represents the expected parameters for POSTLayers. @@ -36,19 +37,19 @@ type POSTLayersParameters struct { // for the analysis. func POSTLayers(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var parameters POSTLayersParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } // Process data. if err := worker.Process(parameters.ID, parameters.ParentID, parameters.Path); err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get engine version and return. - jsonhttp.Render(w, http.StatusCreated, struct{ Version string }{Version: strconv.Itoa(worker.Version)}) + httputils.WriteHTTP(w, http.StatusCreated, struct{ Version string }{Version: strconv.Itoa(worker.Version)}) } // DeleteLayer deletes the specified layer and any child layers that are @@ -56,11 +57,11 @@ func POSTLayers(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func DELETELayers(w http.ResponseWriter, r *http.Request, p httprouter.Params) { err := database.DeleteLayer(p.ByName("id")) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusNoContent, nil) + httputils.WriteHTTP(w, http.StatusNoContent, nil) } // GETLayersOS returns the operating system of a layer if it exists. @@ -70,18 +71,18 @@ func GETLayersOS(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerOS}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get OS. os, err := layer.OperatingSystem() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, struct{ OS string }{OS: os}) + httputils.WriteHTTP(w, http.StatusOK, struct{ OS string }{OS: os}) } // GETLayersParent returns the parent ID of a layer if it exists. @@ -90,14 +91,14 @@ func GETLayersParent(w http.ResponseWriter, r *http.Request, p httprouter.Params // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get layer's parent. parent, err := layer.Parent([]string{database.FieldLayerID}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -105,7 +106,7 @@ func GETLayersParent(w http.ResponseWriter, r *http.Request, p httprouter.Params if parent != nil { ID = parent.ID } - jsonhttp.Render(w, http.StatusOK, struct{ ID string }{ID: ID}) + httputils.WriteHTTP(w, http.StatusOK, struct{ ID string }{ID: ID}) } // GETLayersPackages returns the complete list of packages that a layer has @@ -114,14 +115,14 @@ func GETLayersPackages(w http.ResponseWriter, r *http.Request, p httprouter.Para // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -129,12 +130,12 @@ func GETLayersPackages(w http.ResponseWriter, r *http.Request, p httprouter.Para if len(packagesNodes) > 0 { packages, err = database.FindAllPackagesByNodes(packagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } - jsonhttp.Render(w, http.StatusOK, struct{ Packages []*database.Package }{Packages: packages}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Packages []*database.Package }{Packages: packages}) } // GETLayersPackagesDiff returns the list of packages that a layer installs and @@ -143,7 +144,7 @@ func GETLayersPackagesDiff(w http.ResponseWriter, r *http.Request, p httprouter. // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -152,19 +153,19 @@ func GETLayersPackagesDiff(w http.ResponseWriter, r *http.Request, p httprouter. if len(layer.InstalledPackagesNodes) > 0 { installedPackages, err = database.FindAllPackagesByNodes(layer.InstalledPackagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } if len(layer.RemovedPackagesNodes) > 0 { removedPackages, err = database.FindAllPackagesByNodes(layer.RemovedPackagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } - jsonhttp.Render(w, http.StatusOK, struct{ InstalledPackages, RemovedPackages []*database.Package }{InstalledPackages: installedPackages, RemovedPackages: removedPackages}) + httputils.WriteHTTP(w, http.StatusOK, struct{ InstalledPackages, RemovedPackages []*database.Package }{InstalledPackages: installedPackages, RemovedPackages: removedPackages}) } // GETLayersVulnerabilities returns the complete list of vulnerabilities that @@ -175,32 +176,32 @@ func GETLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p httprout if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities. vulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(packagesNodes, minimumPriority, []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityCausedByPackage}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities}) } // GETLayersVulnerabilitiesDiff returns the list of vulnerabilities that a layer @@ -211,14 +212,14 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -228,14 +229,14 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http // Find vulnerabilities for installed packages. addedVulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(layer.InstalledPackagesNodes, minimumPriority, selectedFields) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities for removed packages. removedVulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(layer.RemovedPackagesNodes, minimumPriority, selectedFields) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -249,7 +250,7 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http } } - jsonhttp.Render(w, http.StatusOK, struct{ Adds, Removes []*database.Vulnerability }{Adds: addedVulnerabilities, Removes: removedVulnerabilities}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Adds, Removes []*database.Vulnerability }{Adds: addedVulnerabilities, Removes: removedVulnerabilities}) } // POSTBatchLayersVulnerabilitiesParameters represents the expected parameters @@ -263,12 +264,12 @@ type POSTBatchLayersVulnerabilitiesParameters struct { func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Parse body var parameters POSTBatchLayersVulnerabilitiesParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } if len(parameters.LayersIDs) == 0 { - jsonhttp.RenderError(w, http.StatusBadRequest, errors.New("at least one LayerID query parameter must be provided")) + httputils.WriteHTTPError(w, http.StatusBadRequest, errors.New("at least one LayerID query parameter must be provided")) return } @@ -277,7 +278,7 @@ func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p ht if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } @@ -287,28 +288,28 @@ func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p ht // Find layer layer, err := database.FindOneLayerByID(layerID, []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities. vulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(packagesNodes, minimumPriority, []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityCausedByPackage}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } response[layerID] = struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities} } - jsonhttp.Render(w, http.StatusOK, response) + httputils.WriteHTTP(w, http.StatusOK, response) } // getSuccessorsFromPackagesNodes returns the node list of packages that have diff --git a/api/logic/vulnerabilities.go b/api/logic/vulnerabilities.go index 695c7359..2f089112 100644 --- a/api/logic/vulnerabilities.go +++ b/api/logic/vulnerabilities.go @@ -18,10 +18,11 @@ import ( "errors" "net/http" - "github.com/coreos/clair/api/jsonhttp" + "github.com/julienschmidt/httprouter" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" - "github.com/julienschmidt/httprouter" + httputils "github.com/coreos/clair/utils/http" ) // GETVulnerabilities returns a vulnerability identified by an ID if it exists. @@ -29,36 +30,36 @@ func GETVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Find vulnerability. vulnerability, err := database.FindOneVulnerability(p.ByName("id"), []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityFixedIn}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } abstractVulnerability, err := vulnerability.ToAbstractVulnerability() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, abstractVulnerability) + httputils.WriteHTTP(w, http.StatusOK, abstractVulnerability) } // POSTVulnerabilities manually inserts a vulnerability into the database if it // does not exist yet. func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { var parameters *database.AbstractVulnerability - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } // Ensure that the vulnerability does not exist. vulnerability, err := database.FindOneVulnerability(parameters.ID, []string{}) if err != nil && err != cerrors.ErrNotFound { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } if vulnerability != nil { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("vulnerability already exists")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("vulnerability already exists")) return } @@ -66,7 +67,7 @@ func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Pa packages := database.AbstractPackagesToPackages(parameters.AffectedPackages) err = database.InsertPackages(packages) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } var pkgNodes []string @@ -77,25 +78,25 @@ func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Pa // Insert vulnerability. notifications, err := database.InsertVulnerabilities([]*database.Vulnerability{parameters.ToVulnerability(pkgNodes)}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Insert notifications. err = database.InsertNotifications(notifications, database.GetDefaultNotificationWrapper()) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusCreated, nil) + httputils.WriteHTTP(w, http.StatusCreated, nil) } // PUTVulnerabilities updates a vulnerability if it exists. func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { var parameters *database.AbstractVulnerability - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } parameters.ID = p.ByName("id") @@ -103,7 +104,7 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Ensure that the vulnerability exists. _, err := database.FindOneVulnerability(parameters.ID, []string{}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -111,7 +112,7 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par packages := database.AbstractPackagesToPackages(parameters.AffectedPackages) err = database.InsertPackages(packages) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } var pkgNodes []string @@ -122,29 +123,29 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Insert vulnerability. notifications, err := database.InsertVulnerabilities([]*database.Vulnerability{parameters.ToVulnerability(pkgNodes)}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Insert notifications. err = database.InsertNotifications(notifications, database.GetDefaultNotificationWrapper()) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusCreated, nil) + httputils.WriteHTTP(w, http.StatusCreated, nil) } // DELVulnerabilities deletes a vulnerability if it exists. func DELVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { err := database.DeleteVulnerability(p.ByName("id")) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusNoContent, nil) + httputils.WriteHTTP(w, http.StatusNoContent, nil) } // GETVulnerabilitiesIntroducingLayers returns the list of layers that @@ -155,13 +156,13 @@ func GETVulnerabilitiesIntroducingLayers(w http.ResponseWriter, r *http.Request, // Find vulnerability to verify that it exists. _, err := database.FindOneVulnerability(p.ByName("id"), []string{}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } layers, err := database.FindAllLayersIntroducingVulnerability(p.ByName("id"), []string{database.FieldLayerID}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -170,7 +171,7 @@ func GETVulnerabilitiesIntroducingLayers(w http.ResponseWriter, r *http.Request, layersIDs = append(layersIDs, l.ID) } - jsonhttp.Render(w, http.StatusOK, struct{ IntroducingLayersIDs []string }{IntroducingLayersIDs: layersIDs}) + httputils.WriteHTTP(w, http.StatusOK, struct{ IntroducingLayersIDs []string }{IntroducingLayersIDs: layersIDs}) } // POSTVulnerabilitiesAffectedLayersParameters represents the expected @@ -184,19 +185,19 @@ type POSTVulnerabilitiesAffectedLayersParameters struct { func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Parse body. var parameters POSTBatchLayersVulnerabilitiesParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } if len(parameters.LayersIDs) == 0 { - jsonhttp.RenderError(w, http.StatusBadRequest, errors.New("getting the entire list of affected layers is not supported yet: at least one LayerID query parameter must be provided")) + httputils.WriteHTTPError(w, http.StatusBadRequest, errors.New("getting the entire list of affected layers is not supported yet: at least one LayerID query parameter must be provided")) return } // Find vulnerability. vulnerability, err := database.FindOneVulnerability(p.ByName("id"), []string{database.FieldVulnerabilityFixedIn}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -212,21 +213,21 @@ func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p // Find layer layer, err := database.FindOneLayerByID(layerID, []string{database.FieldLayerParent, database.FieldLayerPackages, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get successors packages of layer' packages. successors, err := getSuccessorsFromPackagesNodes(packagesNodes) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -243,5 +244,5 @@ func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p response[layerID] = struct{ Vulnerable bool }{Vulnerable: vulnerable} } - jsonhttp.Render(w, http.StatusOK, response) + httputils.WriteHTTP(w, http.StatusOK, response) } diff --git a/api/wrappers/timeout.go b/api/wrappers/timeout.go index dac70de6..d33f3be1 100644 --- a/api/wrappers/timeout.go +++ b/api/wrappers/timeout.go @@ -19,13 +19,13 @@ package wrappers import ( "errors" - "fmt" "net/http" "sync" "time" - "github.com/coreos/clair/api/jsonhttp" "github.com/julienschmidt/httprouter" + + httputils "github.com/coreos/clair/utils/http" ) // ErrHandlerTimeout is returned on ResponseWriter Write calls @@ -77,7 +77,6 @@ func (tw *timeoutWriter) WriteHeader(status int) { // If the duration is 0, the wrapper does nothing. func TimeOut(d time.Duration, fn httprouter.Handle) httprouter.Handle { if d == 0 { - fmt.Println("nope timeout") return fn } @@ -97,7 +96,7 @@ func TimeOut(d time.Duration, fn httprouter.Handle) httprouter.Handle { tw.mu.Lock() defer tw.mu.Unlock() if !tw.wroteHeader { - jsonhttp.RenderError(tw.ResponseWriter, http.StatusServiceUnavailable, ErrHandlerTimeout) + httputils.WriteHTTPError(tw.ResponseWriter, http.StatusServiceUnavailable, ErrHandlerTimeout) } tw.timedOut = true } diff --git a/notifier/notifier.go b/notifier/notifier.go index dc4995ea..aabe8138 100644 --- a/notifier/notifier.go +++ b/notifier/notifier.go @@ -55,7 +55,7 @@ type Notifier struct { } // Config represents the configuration of a Notifier. -// The certificates are optionnals and enables client certificate authentification. +// The certificates are optionnal and enable client certificate authentification. type Config struct { Endpoint string CertFile, KeyFile, CAFile string diff --git a/utils/http/http.go b/utils/http/http.go index 6250d733..1b1c0f8d 100644 --- a/utils/http/http.go +++ b/utils/http/http.go @@ -18,9 +18,19 @@ package http import ( "crypto/tls" "crypto/x509" + "encoding/json" + "io" "io/ioutil" + "net/http" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/worker" ) +// MaxPostSize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body. +const MaxBodySize int64 = 1048576 + // LoadTLSClientConfig initializes a *tls.Config using the given certificates and private key, that // can be used to communicate with a server using client certificate authentificate. // @@ -53,3 +63,75 @@ func LoadTLSClientConfig(certFile, keyFile, caFile string) (*tls.Config, error) return tlsConfig, nil } + +// LoadTLSClientConfigForServer initializes a *tls.Config using the given CA, that can be used to +// configure http server to do client certificate authentification. +// +// If no CA is given, a nil *tls.Config is returned: no client certificate will be required and +// verified. In other words, authentification will be disabled. +func LoadTLSClientConfigForServer(caFile string) (*tls.Config, error) { + if len(caFile) == 0 { + return nil, nil + } + + caCert, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + return tlsConfig, nil +} + +// WriteHTTP writes a JSON-encoded object to a http.ResponseWriter, as well as +// a HTTP status code. +func WriteHTTP(w http.ResponseWriter, httpStatus int, v interface{}) { + w.WriteHeader(httpStatus) + if v != nil { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + result, _ := json.Marshal(v) + w.Write(result) + } +} + +// WriteHTTPError writes an error, wrapped in the Message field of a JSON-encoded +// object to a http.ResponseWriter, as well as a HTTP status code. +// If the status code is 0, handleError tries to guess the proper HTTP status +// code from the error type. +func WriteHTTPError(w http.ResponseWriter, httpStatus int, err error) { + if httpStatus == 0 { + httpStatus = http.StatusInternalServerError + // Try to guess the http status code from the error type + if _, isBadRequestError := err.(*cerrors.ErrBadRequest); isBadRequestError { + httpStatus = http.StatusBadRequest + } else { + switch err { + case cerrors.ErrNotFound: + httpStatus = http.StatusNotFound + case database.ErrTransaction, database.ErrBackendException: + httpStatus = http.StatusServiceUnavailable + case worker.ErrParentUnknown, worker.ErrUnsupported: + httpStatus = http.StatusBadRequest + } + } + } + + WriteHTTP(w, httpStatus, struct{ Message string }{Message: err.Error()}) +} + +// ParseHTTPBody reads a JSON-encoded body from a http.Request and unmarshals it +// into the provided object. +func ParseHTTPBody(r *http.Request, v interface{}) (int, error) { + defer r.Body.Close() + err := json.NewDecoder(io.LimitReader(r.Body, MaxBodySize)).Decode(v) + if err != nil { + return http.StatusUnsupportedMediaType, err + } + return 0, nil +}