diff --git a/api/api.go b/api/api.go index deb18d67..c69825d1 100644 --- a/api/api.go +++ b/api/api.go @@ -17,18 +17,16 @@ package api import ( - "io/ioutil" "net" "net/http" "strconv" "time" - "crypto/tls" - "crypto/x509" - "github.com/coreos/pkg/capnslog" - "github.com/coreos/clair/utils" "github.com/tylerb/graceful" + + "github.com/coreos/clair/utils" + httputils "github.com/coreos/clair/utils/http" ) var log = capnslog.NewPackageLogger("github.com/coreos/clair", "api") @@ -49,12 +47,20 @@ func RunMain(conf *Config, st *utils.Stopper) { st.End() }() + tlsConfig, err := httputils.LoadTLSClientConfigForServer(conf.CAFile) + if err != nil { + log.Fatalf("could not initialize client cert authentification: %s\n", err) + } + if tlsConfig != nil { + log.Info("api configured with client certificate authentification") + } + srv := &graceful.Server{ Timeout: 0, // Already handled by our TimeOut middleware NoSignalHandling: true, // We want to use our own Stopper Server: &http.Server{ Addr: ":" + strconv.Itoa(conf.Port), - TLSConfig: setupClientCert(conf.CAFile), + TLSConfig: tlsConfig, Handler: NewVersionRouter(conf.TimeOut), }, } @@ -102,25 +108,3 @@ func listenAndServeWithStopper(srv *graceful.Server, st *utils.Stopper, certFile log.Fatal(err) } } - -// setupClientCert creates a tls.Config instance using a CA file path -// (if provided) and and calls log.Fatal if it does not exist. -func setupClientCert(caFile string) *tls.Config { - if len(caFile) > 0 { - log.Info("API: Client Certificate Authentification Enabled") - caCert, err := ioutil.ReadFile(caFile) - if err != nil { - log.Fatal(err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - return &tls.Config{ - ClientCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - } - } - - return &tls.Config{ - ClientAuth: tls.NoClientCert, - } -} diff --git a/api/jsonhttp/json.go b/api/jsonhttp/json.go deleted file mode 100644 index b0e39d3f..00000000 --- a/api/jsonhttp/json.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2015 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package jsonhttp provides helper functions to write JSON responses to -// http.ResponseWriter and read JSON bodies from http.Request. -package jsonhttp - -import ( - "encoding/json" - "io" - "net/http" - - "github.com/coreos/clair/database" - cerrors "github.com/coreos/clair/utils/errors" - "github.com/coreos/clair/worker" -) - -// MaxPostSize is the maximum number of bytes that ParseBody reads from an -// http.Request.Body. -var MaxPostSize int64 = 1048576 - -// Render writes a JSON-encoded object to a http.ResponseWriter, as well as -// a HTTP status code. -func Render(w http.ResponseWriter, httpStatus int, v interface{}) { - w.WriteHeader(httpStatus) - if v != nil { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - result, _ := json.Marshal(v) - w.Write(result) - } -} - -// RenderError writes an error, wrapped in the Message field of a JSON-encoded -// object to a http.ResponseWriter, as well as a HTTP status code. -// If the status code is 0, RenderError tries to guess the proper HTTP status -// code from the error type. -func RenderError(w http.ResponseWriter, httpStatus int, err error) { - if httpStatus == 0 { - httpStatus = http.StatusInternalServerError - // Try to guess the http status code from the error type - if _, isBadRequestError := err.(*cerrors.ErrBadRequest); isBadRequestError { - httpStatus = http.StatusBadRequest - } else { - switch err { - case cerrors.ErrNotFound: - httpStatus = http.StatusNotFound - case database.ErrTransaction, database.ErrBackendException: - httpStatus = http.StatusServiceUnavailable - case worker.ErrParentUnknown, worker.ErrUnsupported: - httpStatus = http.StatusBadRequest - } - } - } - - Render(w, httpStatus, struct{ Message string }{Message: err.Error()}) -} - -// ParseBody reads a JSON-encoded body from a http.Request and unmarshals it -// into the provided object. -func ParseBody(r *http.Request, v interface{}) (int, error) { - defer r.Body.Close() - err := json.NewDecoder(io.LimitReader(r.Body, MaxPostSize)).Decode(v) - if err != nil { - return http.StatusUnsupportedMediaType, err - } - return 0, nil -} diff --git a/api/logic/general.go b/api/logic/general.go index 3a040c92..50fa6e7e 100644 --- a/api/logic/general.go +++ b/api/logic/general.go @@ -20,10 +20,11 @@ import ( "net/http" "strconv" - "github.com/coreos/clair/api/jsonhttp" - "github.com/coreos/clair/health" - "github.com/coreos/clair/worker" "github.com/julienschmidt/httprouter" + + "github.com/coreos/clair/health" + httputils "github.com/coreos/clair/utils/http" + "github.com/coreos/clair/worker" ) // Version is an integer representing the API version. @@ -31,7 +32,7 @@ const Version = 1 // GETVersions returns API and Engine versions. func GETVersions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - jsonhttp.Render(w, http.StatusOK, struct { + httputils.WriteHTTP(w, http.StatusOK, struct { APIVersion string EngineVersion string }{ @@ -49,6 +50,6 @@ func GETHealth(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { httpStatus = http.StatusServiceUnavailable } - jsonhttp.Render(w, httpStatus, statuses) + httputils.WriteHTTP(w, httpStatus, statuses) return } diff --git a/api/logic/layers.go b/api/logic/layers.go index 47c1f9bc..5116d3a1 100644 --- a/api/logic/layers.go +++ b/api/logic/layers.go @@ -19,12 +19,13 @@ import ( "net/http" "strconv" - "github.com/coreos/clair/api/jsonhttp" + "github.com/julienschmidt/httprouter" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" + httputils "github.com/coreos/clair/utils/http" "github.com/coreos/clair/utils/types" "github.com/coreos/clair/worker" - "github.com/julienschmidt/httprouter" ) // POSTLayersParameters represents the expected parameters for POSTLayers. @@ -36,19 +37,19 @@ type POSTLayersParameters struct { // for the analysis. func POSTLayers(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { var parameters POSTLayersParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } // Process data. if err := worker.Process(parameters.ID, parameters.ParentID, parameters.Path); err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get engine version and return. - jsonhttp.Render(w, http.StatusCreated, struct{ Version string }{Version: strconv.Itoa(worker.Version)}) + httputils.WriteHTTP(w, http.StatusCreated, struct{ Version string }{Version: strconv.Itoa(worker.Version)}) } // DeleteLayer deletes the specified layer and any child layers that are @@ -56,11 +57,11 @@ func POSTLayers(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func DELETELayers(w http.ResponseWriter, r *http.Request, p httprouter.Params) { err := database.DeleteLayer(p.ByName("id")) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusNoContent, nil) + httputils.WriteHTTP(w, http.StatusNoContent, nil) } // GETLayersOS returns the operating system of a layer if it exists. @@ -70,18 +71,18 @@ func GETLayersOS(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerOS}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get OS. os, err := layer.OperatingSystem() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, struct{ OS string }{OS: os}) + httputils.WriteHTTP(w, http.StatusOK, struct{ OS string }{OS: os}) } // GETLayersParent returns the parent ID of a layer if it exists. @@ -90,14 +91,14 @@ func GETLayersParent(w http.ResponseWriter, r *http.Request, p httprouter.Params // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get layer's parent. parent, err := layer.Parent([]string{database.FieldLayerID}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -105,7 +106,7 @@ func GETLayersParent(w http.ResponseWriter, r *http.Request, p httprouter.Params if parent != nil { ID = parent.ID } - jsonhttp.Render(w, http.StatusOK, struct{ ID string }{ID: ID}) + httputils.WriteHTTP(w, http.StatusOK, struct{ ID string }{ID: ID}) } // GETLayersPackages returns the complete list of packages that a layer has @@ -114,14 +115,14 @@ func GETLayersPackages(w http.ResponseWriter, r *http.Request, p httprouter.Para // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -129,12 +130,12 @@ func GETLayersPackages(w http.ResponseWriter, r *http.Request, p httprouter.Para if len(packagesNodes) > 0 { packages, err = database.FindAllPackagesByNodes(packagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } - jsonhttp.Render(w, http.StatusOK, struct{ Packages []*database.Package }{Packages: packages}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Packages []*database.Package }{Packages: packages}) } // GETLayersPackagesDiff returns the list of packages that a layer installs and @@ -143,7 +144,7 @@ func GETLayersPackagesDiff(w http.ResponseWriter, r *http.Request, p httprouter. // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -152,19 +153,19 @@ func GETLayersPackagesDiff(w http.ResponseWriter, r *http.Request, p httprouter. if len(layer.InstalledPackagesNodes) > 0 { installedPackages, err = database.FindAllPackagesByNodes(layer.InstalledPackagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } if len(layer.RemovedPackagesNodes) > 0 { removedPackages, err = database.FindAllPackagesByNodes(layer.RemovedPackagesNodes, []string{database.FieldPackageOS, database.FieldPackageName, database.FieldPackageVersion}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } } - jsonhttp.Render(w, http.StatusOK, struct{ InstalledPackages, RemovedPackages []*database.Package }{InstalledPackages: installedPackages, RemovedPackages: removedPackages}) + httputils.WriteHTTP(w, http.StatusOK, struct{ InstalledPackages, RemovedPackages []*database.Package }{InstalledPackages: installedPackages, RemovedPackages: removedPackages}) } // GETLayersVulnerabilities returns the complete list of vulnerabilities that @@ -175,32 +176,32 @@ func GETLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p httprout if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } // Find layer layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities. vulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(packagesNodes, minimumPriority, []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityCausedByPackage}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities}) } // GETLayersVulnerabilitiesDiff returns the list of vulnerabilities that a layer @@ -211,14 +212,14 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } // Find layer. layer, err := database.FindOneLayerByID(p.ByName("id"), []string{database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -228,14 +229,14 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http // Find vulnerabilities for installed packages. addedVulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(layer.InstalledPackagesNodes, minimumPriority, selectedFields) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities for removed packages. removedVulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(layer.RemovedPackagesNodes, minimumPriority, selectedFields) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -249,7 +250,7 @@ func GETLayersVulnerabilitiesDiff(w http.ResponseWriter, r *http.Request, p http } } - jsonhttp.Render(w, http.StatusOK, struct{ Adds, Removes []*database.Vulnerability }{Adds: addedVulnerabilities, Removes: removedVulnerabilities}) + httputils.WriteHTTP(w, http.StatusOK, struct{ Adds, Removes []*database.Vulnerability }{Adds: addedVulnerabilities, Removes: removedVulnerabilities}) } // POSTBatchLayersVulnerabilitiesParameters represents the expected parameters @@ -263,12 +264,12 @@ type POSTBatchLayersVulnerabilitiesParameters struct { func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Parse body var parameters POSTBatchLayersVulnerabilitiesParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } if len(parameters.LayersIDs) == 0 { - jsonhttp.RenderError(w, http.StatusBadRequest, errors.New("at least one LayerID query parameter must be provided")) + httputils.WriteHTTPError(w, http.StatusBadRequest, errors.New("at least one LayerID query parameter must be provided")) return } @@ -277,7 +278,7 @@ func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p ht if minimumPriority == "" { minimumPriority = "High" // Set default priority to High } else if !minimumPriority.IsValid() { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("invalid priority")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("invalid priority")) return } @@ -287,28 +288,28 @@ func POSTBatchLayersVulnerabilities(w http.ResponseWriter, r *http.Request, p ht // Find layer layer, err := database.FindOneLayerByID(layerID, []string{database.FieldLayerParent, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find vulnerabilities. vulnerabilities, err := getVulnerabilitiesFromLayerPackagesNodes(packagesNodes, minimumPriority, []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityCausedByPackage}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } response[layerID] = struct{ Vulnerabilities []*database.Vulnerability }{Vulnerabilities: vulnerabilities} } - jsonhttp.Render(w, http.StatusOK, response) + httputils.WriteHTTP(w, http.StatusOK, response) } // getSuccessorsFromPackagesNodes returns the node list of packages that have diff --git a/api/logic/vulnerabilities.go b/api/logic/vulnerabilities.go index 695c7359..2f089112 100644 --- a/api/logic/vulnerabilities.go +++ b/api/logic/vulnerabilities.go @@ -18,10 +18,11 @@ import ( "errors" "net/http" - "github.com/coreos/clair/api/jsonhttp" + "github.com/julienschmidt/httprouter" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" - "github.com/julienschmidt/httprouter" + httputils "github.com/coreos/clair/utils/http" ) // GETVulnerabilities returns a vulnerability identified by an ID if it exists. @@ -29,36 +30,36 @@ func GETVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Find vulnerability. vulnerability, err := database.FindOneVulnerability(p.ByName("id"), []string{database.FieldVulnerabilityID, database.FieldVulnerabilityLink, database.FieldVulnerabilityPriority, database.FieldVulnerabilityDescription, database.FieldVulnerabilityFixedIn}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } abstractVulnerability, err := vulnerability.ToAbstractVulnerability() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusOK, abstractVulnerability) + httputils.WriteHTTP(w, http.StatusOK, abstractVulnerability) } // POSTVulnerabilities manually inserts a vulnerability into the database if it // does not exist yet. func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { var parameters *database.AbstractVulnerability - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } // Ensure that the vulnerability does not exist. vulnerability, err := database.FindOneVulnerability(parameters.ID, []string{}) if err != nil && err != cerrors.ErrNotFound { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } if vulnerability != nil { - jsonhttp.RenderError(w, 0, cerrors.NewBadRequestError("vulnerability already exists")) + httputils.WriteHTTPError(w, 0, cerrors.NewBadRequestError("vulnerability already exists")) return } @@ -66,7 +67,7 @@ func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Pa packages := database.AbstractPackagesToPackages(parameters.AffectedPackages) err = database.InsertPackages(packages) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } var pkgNodes []string @@ -77,25 +78,25 @@ func POSTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Pa // Insert vulnerability. notifications, err := database.InsertVulnerabilities([]*database.Vulnerability{parameters.ToVulnerability(pkgNodes)}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Insert notifications. err = database.InsertNotifications(notifications, database.GetDefaultNotificationWrapper()) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusCreated, nil) + httputils.WriteHTTP(w, http.StatusCreated, nil) } // PUTVulnerabilities updates a vulnerability if it exists. func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { var parameters *database.AbstractVulnerability - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } parameters.ID = p.ByName("id") @@ -103,7 +104,7 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Ensure that the vulnerability exists. _, err := database.FindOneVulnerability(parameters.ID, []string{}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -111,7 +112,7 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par packages := database.AbstractPackagesToPackages(parameters.AffectedPackages) err = database.InsertPackages(packages) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } var pkgNodes []string @@ -122,29 +123,29 @@ func PUTVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Par // Insert vulnerability. notifications, err := database.InsertVulnerabilities([]*database.Vulnerability{parameters.ToVulnerability(pkgNodes)}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Insert notifications. err = database.InsertNotifications(notifications, database.GetDefaultNotificationWrapper()) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusCreated, nil) + httputils.WriteHTTP(w, http.StatusCreated, nil) } // DELVulnerabilities deletes a vulnerability if it exists. func DELVulnerabilities(w http.ResponseWriter, r *http.Request, p httprouter.Params) { err := database.DeleteVulnerability(p.ByName("id")) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } - jsonhttp.Render(w, http.StatusNoContent, nil) + httputils.WriteHTTP(w, http.StatusNoContent, nil) } // GETVulnerabilitiesIntroducingLayers returns the list of layers that @@ -155,13 +156,13 @@ func GETVulnerabilitiesIntroducingLayers(w http.ResponseWriter, r *http.Request, // Find vulnerability to verify that it exists. _, err := database.FindOneVulnerability(p.ByName("id"), []string{}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } layers, err := database.FindAllLayersIntroducingVulnerability(p.ByName("id"), []string{database.FieldLayerID}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -170,7 +171,7 @@ func GETVulnerabilitiesIntroducingLayers(w http.ResponseWriter, r *http.Request, layersIDs = append(layersIDs, l.ID) } - jsonhttp.Render(w, http.StatusOK, struct{ IntroducingLayersIDs []string }{IntroducingLayersIDs: layersIDs}) + httputils.WriteHTTP(w, http.StatusOK, struct{ IntroducingLayersIDs []string }{IntroducingLayersIDs: layersIDs}) } // POSTVulnerabilitiesAffectedLayersParameters represents the expected @@ -184,19 +185,19 @@ type POSTVulnerabilitiesAffectedLayersParameters struct { func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p httprouter.Params) { // Parse body. var parameters POSTBatchLayersVulnerabilitiesParameters - if s, err := jsonhttp.ParseBody(r, ¶meters); err != nil { - jsonhttp.RenderError(w, s, err) + if s, err := httputils.ParseHTTPBody(r, ¶meters); err != nil { + httputils.WriteHTTPError(w, s, err) return } if len(parameters.LayersIDs) == 0 { - jsonhttp.RenderError(w, http.StatusBadRequest, errors.New("getting the entire list of affected layers is not supported yet: at least one LayerID query parameter must be provided")) + httputils.WriteHTTPError(w, http.StatusBadRequest, errors.New("getting the entire list of affected layers is not supported yet: at least one LayerID query parameter must be provided")) return } // Find vulnerability. vulnerability, err := database.FindOneVulnerability(p.ByName("id"), []string{database.FieldVulnerabilityFixedIn}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -212,21 +213,21 @@ func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p // Find layer layer, err := database.FindOneLayerByID(layerID, []string{database.FieldLayerParent, database.FieldLayerPackages, database.FieldLayerPackages}) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Find layer's packages. packagesNodes, err := layer.AllPackages() if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } // Get successors packages of layer' packages. successors, err := getSuccessorsFromPackagesNodes(packagesNodes) if err != nil { - jsonhttp.RenderError(w, 0, err) + httputils.WriteHTTPError(w, 0, err) return } @@ -243,5 +244,5 @@ func POSTVulnerabilitiesAffectedLayers(w http.ResponseWriter, r *http.Request, p response[layerID] = struct{ Vulnerable bool }{Vulnerable: vulnerable} } - jsonhttp.Render(w, http.StatusOK, response) + httputils.WriteHTTP(w, http.StatusOK, response) } diff --git a/api/wrappers/timeout.go b/api/wrappers/timeout.go index dac70de6..d33f3be1 100644 --- a/api/wrappers/timeout.go +++ b/api/wrappers/timeout.go @@ -19,13 +19,13 @@ package wrappers import ( "errors" - "fmt" "net/http" "sync" "time" - "github.com/coreos/clair/api/jsonhttp" "github.com/julienschmidt/httprouter" + + httputils "github.com/coreos/clair/utils/http" ) // ErrHandlerTimeout is returned on ResponseWriter Write calls @@ -77,7 +77,6 @@ func (tw *timeoutWriter) WriteHeader(status int) { // If the duration is 0, the wrapper does nothing. func TimeOut(d time.Duration, fn httprouter.Handle) httprouter.Handle { if d == 0 { - fmt.Println("nope timeout") return fn } @@ -97,7 +96,7 @@ func TimeOut(d time.Duration, fn httprouter.Handle) httprouter.Handle { tw.mu.Lock() defer tw.mu.Unlock() if !tw.wroteHeader { - jsonhttp.RenderError(tw.ResponseWriter, http.StatusServiceUnavailable, ErrHandlerTimeout) + httputils.WriteHTTPError(tw.ResponseWriter, http.StatusServiceUnavailable, ErrHandlerTimeout) } tw.timedOut = true } diff --git a/notifier/notifier.go b/notifier/notifier.go index dc4995ea..aabe8138 100644 --- a/notifier/notifier.go +++ b/notifier/notifier.go @@ -55,7 +55,7 @@ type Notifier struct { } // Config represents the configuration of a Notifier. -// The certificates are optionnals and enables client certificate authentification. +// The certificates are optionnal and enable client certificate authentification. type Config struct { Endpoint string CertFile, KeyFile, CAFile string diff --git a/utils/http/http.go b/utils/http/http.go index 6250d733..1b1c0f8d 100644 --- a/utils/http/http.go +++ b/utils/http/http.go @@ -18,9 +18,19 @@ package http import ( "crypto/tls" "crypto/x509" + "encoding/json" + "io" "io/ioutil" + "net/http" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/worker" ) +// MaxPostSize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body. +const MaxBodySize int64 = 1048576 + // LoadTLSClientConfig initializes a *tls.Config using the given certificates and private key, that // can be used to communicate with a server using client certificate authentificate. // @@ -53,3 +63,75 @@ func LoadTLSClientConfig(certFile, keyFile, caFile string) (*tls.Config, error) return tlsConfig, nil } + +// LoadTLSClientConfigForServer initializes a *tls.Config using the given CA, that can be used to +// configure http server to do client certificate authentification. +// +// If no CA is given, a nil *tls.Config is returned: no client certificate will be required and +// verified. In other words, authentification will be disabled. +func LoadTLSClientConfigForServer(caFile string) (*tls.Config, error) { + if len(caFile) == 0 { + return nil, nil + } + + caCert, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + return tlsConfig, nil +} + +// WriteHTTP writes a JSON-encoded object to a http.ResponseWriter, as well as +// a HTTP status code. +func WriteHTTP(w http.ResponseWriter, httpStatus int, v interface{}) { + w.WriteHeader(httpStatus) + if v != nil { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + result, _ := json.Marshal(v) + w.Write(result) + } +} + +// WriteHTTPError writes an error, wrapped in the Message field of a JSON-encoded +// object to a http.ResponseWriter, as well as a HTTP status code. +// If the status code is 0, handleError tries to guess the proper HTTP status +// code from the error type. +func WriteHTTPError(w http.ResponseWriter, httpStatus int, err error) { + if httpStatus == 0 { + httpStatus = http.StatusInternalServerError + // Try to guess the http status code from the error type + if _, isBadRequestError := err.(*cerrors.ErrBadRequest); isBadRequestError { + httpStatus = http.StatusBadRequest + } else { + switch err { + case cerrors.ErrNotFound: + httpStatus = http.StatusNotFound + case database.ErrTransaction, database.ErrBackendException: + httpStatus = http.StatusServiceUnavailable + case worker.ErrParentUnknown, worker.ErrUnsupported: + httpStatus = http.StatusBadRequest + } + } + } + + WriteHTTP(w, httpStatus, struct{ Message string }{Message: err.Error()}) +} + +// ParseHTTPBody reads a JSON-encoded body from a http.Request and unmarshals it +// into the provided object. +func ParseHTTPBody(r *http.Request, v interface{}) (int, error) { + defer r.Body.Close() + err := json.NewDecoder(io.LimitReader(r.Body, MaxBodySize)).Decode(v) + if err != nil { + return http.StatusUnsupportedMediaType, err + } + return 0, nil +}