Refactoring (minor)

This commit is contained in:
Eric Sim 2018-12-17 14:36:04 -08:00
parent 8e98ee878a
commit 684ae2be1d
2 changed files with 75 additions and 56 deletions

View File

@ -26,6 +26,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"fmt"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/rpm" "github.com/coreos/clair/ext/versionfmt/rpm"
@ -35,50 +36,55 @@ import (
) )
const ( const (
amazonLinux1Name = "Amazon Linux 2018.03"
amazonLinux1Namespace = "amzn:2018.03"
amazonLinux1UpdaterFlag = "amazonLinux1Updater" amazonLinux1UpdaterFlag = "amazonLinux1Updater"
amazonLinux1MirrorListURI = "http://repo.us-west-2.amazonaws.com/2018.03/updates/x86_64/mirror.list" amazonLinux1MirrorListURI = "http://repo.us-west-2.amazonaws.com/2018.03/updates/x86_64/mirror.list"
amazonLinux2Name = "Amazon Linux 2" amazonLinux1Name = "Amazon Linux 2018.03"
amazonLinux2Namespace = "amzn:2" amazonLinux1Namespace = "amzn:2018.03"
amazonLinux1LinkFormat = "https://alas.aws.amazon.com/%s.html"
amazonLinux2UpdaterFlag = "amazonLinux2Updater" amazonLinux2UpdaterFlag = "amazonLinux2Updater"
amazonLinux2MirrorListURI = "https://cdn.amazonlinux.com/2/core/latest/x86_64/mirror.list" amazonLinux2MirrorListURI = "https://cdn.amazonlinux.com/2/core/latest/x86_64/mirror.list"
amazonLinux2Name = "Amazon Linux 2"
amazonLinux2Namespace = "amzn:2"
amazonLinux2LinkFormat = "https://alas.aws.amazon.com/AL2/%s.html"
) )
type updater struct { type updater struct {
Name string
Namespace string
UpdaterFlag string UpdaterFlag string
MirrorListURI string MirrorListURI string
Name string
Namespace string
LinkFormat string
} }
func init() { func init() {
// Register updater for Amazon Linux 2018.03. // Register updater for Amazon Linux 2018.03.
amazonLinux1Updater := updater{ amazonLinux1Updater := updater{
Name: amazonLinux1Name,
Namespace: amazonLinux1Namespace,
UpdaterFlag: amazonLinux1UpdaterFlag, UpdaterFlag: amazonLinux1UpdaterFlag,
MirrorListURI: amazonLinux1MirrorListURI, MirrorListURI: amazonLinux1MirrorListURI,
Name: amazonLinux1Name,
Namespace: amazonLinux1Namespace,
LinkFormat: amazonLinux1LinkFormat,
} }
vulnsrc.RegisterUpdater("amzn", &amazonLinux1Updater) vulnsrc.RegisterUpdater("amzn1", &amazonLinux1Updater)
// Register updater for Amazon Linux 2. // Register updater for Amazon Linux 2.
amazonLinux2Updater := updater{ amazonLinux2Updater := updater{
Name: amazonLinux2Name,
Namespace: amazonLinux2Namespace,
UpdaterFlag: amazonLinux2UpdaterFlag, UpdaterFlag: amazonLinux2UpdaterFlag,
MirrorListURI: amazonLinux2MirrorListURI, MirrorListURI: amazonLinux2MirrorListURI,
Name: amazonLinux2Name,
Namespace: amazonLinux2Namespace,
LinkFormat: amazonLinux2LinkFormat,
} }
vulnsrc.RegisterUpdater("amzn2", &amazonLinux2Updater) vulnsrc.RegisterUpdater("amzn2", &amazonLinux2Updater)
} }
func (u *updater) Update(datastore database.Datastore) (response vulnsrc.UpdateResponse, err error) { func (u *updater) Update(datastore database.Datastore) (vulnsrc.UpdateResponse, error) {
log.WithField("package", u.Name).Info("Start fetching vulnerabilities") log.WithField("package", u.Name).Info("Start fetching vulnerabilities")
// Get the flag value (the timestamp of the latest ALAS of the previous update). // Get the flag value (the timestamp of the latest ALAS of the previous update).
flagValue, found, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag) flagValue, found, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag)
if err != nil { if err != nil {
return response, err return vulnsrc.UpdateResponse{}, err
} }
if !found { if !found {
@ -90,7 +96,7 @@ func (u *updater) Update(datastore database.Datastore) (response vulnsrc.UpdateR
// Get the ALASs from updateinfo.xml.gz from the repos. // Get the ALASs from updateinfo.xml.gz from the repos.
updateInfo, err := u.getUpdateInfo() updateInfo, err := u.getUpdateInfo()
if err != nil { if err != nil {
return response, err return vulnsrc.UpdateResponse{}, err
} }
// Get the ALASs which were issued/updated since the previous update. // Get the ALASs which were issued/updated since the previous update.
@ -106,9 +112,10 @@ func (u *updater) Update(datastore database.Datastore) (response vulnsrc.UpdateR
} }
// Get the vulnerabilities. // Get the vulnerabilities.
response.Vulnerabilities, err = u.alasListToVulnerabilities(alasList) vulnerabilities := u.alasListToVulnerabilities(alasList)
if err != nil {
return response, err response := vulnsrc.UpdateResponse{
Vulnerabilities: vulnerabilities,
} }
// Set the flag value. // Set the flag value.
@ -126,56 +133,56 @@ func (u *updater) Clean() {
} }
func (u *updater) getUpdateInfo() (updateInfo UpdateInfo, err error) { func (u *updater) getUpdateInfo() (UpdateInfo, error) {
// Get the URI of updateinfo.xml.gz. // Get the URI of updateinfo.xml.gz.
updateInfoURI, err := u.getUpdateInfoURI() updateInfoURI, err := u.getUpdateInfoURI()
if err != nil { if err != nil {
return updateInfo, err return UpdateInfo{}, err
} }
// Download updateinfo.xml.gz. // Download updateinfo.xml.gz.
updateInfoResponse, err := httputil.GetWithUserAgent(updateInfoURI) updateInfoResponse, err := httputil.GetWithUserAgent(updateInfoURI)
if err != nil { if err != nil {
log.WithError(err).Error("could not download updateinfo.xml.gz") log.WithError(err).Error("could not download updateinfo.xml.gz")
return updateInfo, commonerr.ErrCouldNotDownload return UpdateInfo{}, commonerr.ErrCouldNotDownload
} }
defer updateInfoResponse.Body.Close() defer updateInfoResponse.Body.Close()
if !httputil.Status2xx(updateInfoResponse) { if !httputil.Status2xx(updateInfoResponse) {
log.WithField("StatusCode", updateInfoResponse.StatusCode).Error("could not download updateinfo.xml.gz") log.WithField("StatusCode", updateInfoResponse.StatusCode).Error("could not download updateinfo.xml.gz")
return updateInfo, commonerr.ErrCouldNotDownload return UpdateInfo{}, commonerr.ErrCouldNotDownload
} }
// Decompress updateinfo.xml.gz. // Decompress updateinfo.xml.gz.
updateInfoXml, err := gzip.NewReader(updateInfoResponse.Body) updateInfoXml, err := gzip.NewReader(updateInfoResponse.Body)
if err != nil { if err != nil {
log.WithError(err).Error("could not decompress updateinfo.xml.gz") log.WithError(err).Error("could not decompress updateinfo.xml.gz")
return updateInfo, commonerr.ErrCouldNotDownload return UpdateInfo{}, commonerr.ErrCouldNotParse
} }
defer updateInfoXml.Close() defer updateInfoXml.Close()
// Decode updateinfo.xml. // Decode updateinfo.xml.
updateInfo, err = decodeUpdateInfo(updateInfoXml) updateInfo, err := decodeUpdateInfo(updateInfoXml)
if err != nil { if err != nil {
log.WithError(err).Error("could not decode updateinfo.xml") log.WithError(err).Error("could not decode updateinfo.xml")
return updateInfo, err return UpdateInfo{}, commonerr.ErrCouldNotParse
} }
return return updateInfo, nil
} }
func (u *updater) getUpdateInfoURI() (updateInfoURI string, err error) { func (u *updater) getUpdateInfoURI() (string, error) {
// Download mirror.list // Download mirror.list
mirrorListResponse, err := httputil.GetWithUserAgent(u.MirrorListURI) mirrorListResponse, err := httputil.GetWithUserAgent(u.MirrorListURI)
if err != nil { if err != nil {
log.WithError(err).Error("could not download mirror list") log.WithError(err).Error("could not download mirror list")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
defer mirrorListResponse.Body.Close() defer mirrorListResponse.Body.Close()
if !httputil.Status2xx(mirrorListResponse) { if !httputil.Status2xx(mirrorListResponse) {
log.WithField("StatusCode", mirrorListResponse.StatusCode).Error("could not download mirror list") log.WithField("StatusCode", mirrorListResponse.StatusCode).Error("could not download mirror list")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
// Parse the URI of the first mirror. // Parse the URI of the first mirror.
@ -191,13 +198,13 @@ func (u *updater) getUpdateInfoURI() (updateInfoURI string, err error) {
repoMdResponse, err := httputil.GetWithUserAgent(repoMdURI) repoMdResponse, err := httputil.GetWithUserAgent(repoMdURI)
if err != nil { if err != nil {
log.WithError(err).Error("could not download repomd.xml") log.WithError(err).Error("could not download repomd.xml")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
defer repoMdResponse.Body.Close() defer repoMdResponse.Body.Close()
if !httputil.Status2xx(repoMdResponse) { if !httputil.Status2xx(repoMdResponse) {
log.WithField("StatusCode", repoMdResponse.StatusCode).Error("could not download repomd.xml") log.WithField("StatusCode", repoMdResponse.StatusCode).Error("could not download repomd.xml")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
// Decode repomd.xml. // Decode repomd.xml.
@ -205,10 +212,11 @@ func (u *updater) getUpdateInfoURI() (updateInfoURI string, err error) {
err = xml.NewDecoder(repoMdResponse.Body).Decode(&repoMd) err = xml.NewDecoder(repoMdResponse.Body).Decode(&repoMd)
if err != nil { if err != nil {
log.WithError(err).Error("could not decode repomd.xml") log.WithError(err).Error("could not decode repomd.xml")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
// Parse the URI of updateinfo.xml.gz. // Parse the URI of updateinfo.xml.gz.
var updateInfoURI string
for _, repo := range repoMd.RepoList { for _, repo := range repoMd.RepoList {
if repo.Type == "updateinfo" { if repo.Type == "updateinfo" {
updateInfoURI = mirrorURI + "/" + repo.Location.Href updateInfoURI = mirrorURI + "/" + repo.Location.Href
@ -217,22 +225,24 @@ func (u *updater) getUpdateInfoURI() (updateInfoURI string, err error) {
} }
if updateInfoURI == "" { if updateInfoURI == "" {
log.Error("could not find updateinfo in repomd.xml") log.Error("could not find updateinfo in repomd.xml")
return updateInfoURI, commonerr.ErrCouldNotDownload return "", commonerr.ErrCouldNotDownload
} }
return return updateInfoURI, nil
} }
func decodeUpdateInfo(updateInfoReader io.Reader) (updateInfo UpdateInfo, err error) { func decodeUpdateInfo(updateInfoReader io.Reader) (UpdateInfo, error) {
err = xml.NewDecoder(updateInfoReader).Decode(&updateInfo) var updateInfo UpdateInfo
err := xml.NewDecoder(updateInfoReader).Decode(&updateInfo)
if err != nil { if err != nil {
return updateInfo, err return updateInfo, err
} }
return return updateInfo, nil
} }
func (u *updater) alasListToVulnerabilities(alasList []ALAS) (vulnerabilities []database.VulnerabilityWithAffected, err error) { func (u *updater) alasListToVulnerabilities(alasList []ALAS) []database.VulnerabilityWithAffected {
var vulnerabilities []database.VulnerabilityWithAffected
for _, alas := range alasList { for _, alas := range alasList {
featureVersions := u.alasToFeatureVersions(alas) featureVersions := u.alasToFeatureVersions(alas)
if len(featureVersions) > 0 { if len(featureVersions) > 0 {
@ -249,7 +259,7 @@ func (u *updater) alasListToVulnerabilities(alasList []ALAS) (vulnerabilities []
} }
} }
return return vulnerabilities
} }
func (u *updater) alasToName(alas ALAS) string { func (u *updater) alasToName(alas ALAS) string {
@ -258,12 +268,16 @@ func (u *updater) alasToName(alas ALAS) string {
func (u *updater) alasToLink(alas ALAS) string { func (u *updater) alasToLink(alas ALAS) string {
if u.Name == amazonLinux1Name { if u.Name == amazonLinux1Name {
return "https://alas.aws.amazon.com/" + alas.Id + ".html" return fmt.Sprintf(u.LinkFormat, alas.Id)
} }
if u.Name == amazonLinux2Name {
// "ALAS2-2018-1097" becomes "https://alas.aws.amazon.com/AL2/ALAS-2018-1097.html". // "ALAS2-2018-1097" becomes "https://alas.aws.amazon.com/AL2/ALAS-2018-1097.html".
re := regexp.MustCompile(`^ALAS2-(.+)$`) re := regexp.MustCompile(`^ALAS2-(.+)$`)
return "https://alas.aws.amazon.com/AL2/ALAS-" + re.FindStringSubmatch(alas.Id)[1] + ".html" return fmt.Sprintf(u.LinkFormat, "ALAS-"+re.FindStringSubmatch(alas.Id)[1])
}
return ""
} }
func (u *updater) alasToSeverity(alas ALAS) database.Severity { func (u *updater) alasToSeverity(alas ALAS) database.Severity {
@ -287,7 +301,8 @@ func (u *updater) alasToDescription(alas ALAS) string {
return re.ReplaceAllString(strings.TrimSpace(alas.Description), " ") return re.ReplaceAllString(strings.TrimSpace(alas.Description), " ")
} }
func (u *updater) alasToFeatureVersions(alas ALAS) (featureVersions []database.AffectedFeature) { func (u *updater) alasToFeatureVersions(alas ALAS) []database.AffectedFeature {
var featureVersions []database.AffectedFeature
for _, p := range alas.Packages { for _, p := range alas.Packages {
var version string var version string
if p.Epoch == "0" { if p.Epoch == "0" {
@ -301,20 +316,24 @@ func (u *updater) alasToFeatureVersions(alas ALAS) (featureVersions []database.A
continue continue
} }
var featureVersion database.AffectedFeature featureVersion := database.AffectedFeature{
featureVersion.Namespace.Name = u.Namespace Namespace: database.Namespace{
featureVersion.Namespace.VersionFormat = rpm.ParserName Name: u.Namespace,
featureVersion.FeatureName = p.Name VersionFormat: rpm.ParserName,
featureVersion.AffectedVersion = version },
FeatureName: p.Name,
AffectedVersion: version,
AffectedType: database.AffectBinaryPackage,
}
if version != versionfmt.MaxVersion { if version != versionfmt.MaxVersion {
featureVersion.FixedInVersion = version featureVersion.FixedInVersion = version
} }
featureVersion.AffectedType = database.AffectBinaryPackage
featureVersions = append(featureVersions, featureVersion) featureVersions = append(featureVersions, featureVersion)
} }
return return featureVersions
} }
func compareTimestamp(date0 string, date1 string) int { func compareTimestamp(date0 string, date1 string) int {

View File

@ -28,10 +28,11 @@ import (
func TestAmazonLinux1(t *testing.T) { func TestAmazonLinux1(t *testing.T) {
amazonLinux1Updater := updater{ amazonLinux1Updater := updater{
MirrorListURI: "http://repo.us-west-2.amazonaws.com/2018.03/updates/x86_64/mirror.list",
Name: "Amazon Linux 2018.03", Name: "Amazon Linux 2018.03",
Namespace: "amzn:2018.03", Namespace: "amzn:2018.03",
UpdaterFlag: "amazonLinux1Updater", UpdaterFlag: "amazonLinux1Updater",
MirrorListURI: "http://repo.us-west-2.amazonaws.com/2018.03/updates/x86_64/mirror.list", LinkFormat: "https://alas.aws.amazon.com/%s.html",
} }
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
@ -49,8 +50,7 @@ func TestAmazonLinux1(t *testing.T) {
updateInfo, err := decodeUpdateInfo(updateInfoXml) updateInfo, err := decodeUpdateInfo(updateInfoXml)
assert.Nil(t, err) assert.Nil(t, err)
vulnerabilities, err := amazonLinux1Updater.alasListToVulnerabilities(updateInfo.ALASList) vulnerabilities := amazonLinux1Updater.alasListToVulnerabilities(updateInfo.ALASList)
assert.Nil(t, err)
assert.Equal(t, "ALAS-2011-1", vulnerabilities[0].Name) assert.Equal(t, "ALAS-2011-1", vulnerabilities[0].Name)
assert.Equal(t, "https://alas.aws.amazon.com/ALAS-2011-1.html", vulnerabilities[0].Link) assert.Equal(t, "https://alas.aws.amazon.com/ALAS-2011-1.html", vulnerabilities[0].Link)
@ -121,10 +121,11 @@ func TestAmazonLinux1(t *testing.T) {
func TestAmazonLinux2(t *testing.T) { func TestAmazonLinux2(t *testing.T) {
amazonLinux2Updater := updater{ amazonLinux2Updater := updater{
MirrorListURI: "https://cdn.amazonlinux.com/2/core/latest/x86_64/mirror.list",
Name: "Amazon Linux 2", Name: "Amazon Linux 2",
Namespace: "amzn:2", Namespace: "amzn:2",
UpdaterFlag: "amazonLinux2Updater", UpdaterFlag: "amazonLinux2Updater",
MirrorListURI: "https://cdn.amazonlinux.com/2/core/latest/x86_64/mirror.list", LinkFormat: "https://alas.aws.amazon.com/AL2/%s.html",
} }
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
@ -142,8 +143,7 @@ func TestAmazonLinux2(t *testing.T) {
updateInfo, err := decodeUpdateInfo(updateInfoXml) updateInfo, err := decodeUpdateInfo(updateInfoXml)
assert.Nil(t, err) assert.Nil(t, err)
vulnerabilities, err := amazonLinux2Updater.alasListToVulnerabilities(updateInfo.ALASList) vulnerabilities := amazonLinux2Updater.alasListToVulnerabilities(updateInfo.ALASList)
assert.Nil(t, err)
assert.Equal(t, "ALAS2-2018-939", vulnerabilities[0].Name) assert.Equal(t, "ALAS2-2018-939", vulnerabilities[0].Name)
assert.Equal(t, "https://alas.aws.amazon.com/AL2/ALAS-2018-939.html", vulnerabilities[0].Link) assert.Equal(t, "https://alas.aws.amazon.com/AL2/ALAS-2018-939.html", vulnerabilities[0].Link)