// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pgsql

import (
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/coreos/clair/database"
	"github.com/coreos/clair/ext/versionfmt/dpkg"
)

// TestInsertVulnerabilities checks that InsertVulnerabilities accepts an empty
// batch, returns an error for a batch containing an invalid entry or a
// duplicated entry, stores a valid vulnerability, and allows a vulnerability to
// be inserted again after it has been deleted.
func TestInsertVulnerabilities(t *testing.T) {
	store, tx := openSessionForTest(t, "InsertVulnerabilities", true)

	invalid := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name: "invalid",
			// The namespace uses a made-up version format; the test expects the
			// batch containing this entry to be rejected.
			Namespace: database.Namespace{Name: "name", VersionFormat: "random stuff"},
		},
	}
	valid := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "valid",
			Namespace: database.Namespace{Name: "debian:7", VersionFormat: "dpkg"},
			Severity:  database.UnknownSeverity,
		},
	}
	validID := database.VulnerabilityID{Name: valid.Name, Namespace: valid.Namespace.Name}

	// An empty batch succeeds.
	assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}))

	// A batch containing an invalid entry returns an error.
	assert.NotNil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{invalid, valid}))

	tx = restartSession(t, store, tx, false)
	// The same vulnerability twice in one batch returns an error.
	assert.NotNil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{valid, valid}))

	tx = restartSession(t, store, tx, false)
	assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{valid}))

	// Restart passing true (presumably: keep the insert above), then read it back.
	tx = restartSession(t, store, tx, true)
	found, err := tx.FindVulnerabilities([]database.VulnerabilityID{validID})
	if assert.Nil(t, err) && assert.Len(t, found, 1) {
		assert.True(t, found[0].Valid)
	}

	tx = restartSession(t, store, tx, false)
	// Deleting and then re-inserting the same vulnerability succeeds.
	assert.Nil(t, tx.DeleteVulnerabilities([]database.VulnerabilityID{validID}))
	assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{valid}))

	// tx is reassigned above, so it is closed explicitly instead of via defer.
	closeTest(t, store, tx)
}

// TestCachingVulnerable inserts two vulnerabilities that both list "openssl" in
// debian:8 as affected, then checks that FindAffectedNamespacedFeatures reports
// the feature as affected by exactly those two. Each one is reported once,
// together with the FixedInVersion from its AffectedFeature entry.
func TestCachingVulnerable(t *testing.T) {
	datastore, tx := openSessionForTest(t, "CachingVulnerable", true)
	defer closeTest(t, datastore, tx)

	ns := database.Namespace{
		Name:          "debian:8",
		VersionFormat: dpkg.ParserName,
	}

	f := database.NamespacedFeature{
		Feature: database.Feature{
			Name:          "openssl",
			Version:       "1.0",
			VersionFormat: dpkg.ParserName,
		},
		Namespace: ns,
	}

	vuln := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "CVE-YAY",
			Namespace: ns,
			Severity:  database.HighSeverity,
		},
		Affected: []database.AffectedFeature{
			{
				Namespace:       ns,
				FeatureName:     "openssl",
				AffectedVersion: "2.0",
				FixedInVersion:  "2.1",
			},
		},
	}

	vuln2 := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "CVE-YAY2",
			Namespace: ns,
			Severity:  database.HighSeverity,
		},
		Affected: []database.AffectedFeature{
			{
				Namespace:       ns,
				FeatureName:     "openssl",
				AffectedVersion: "2.1",
				FixedInVersion:  "2.2",
			},
		},
	}

	vulnFixed1 := database.VulnerabilityWithFixedIn{
		Vulnerability: database.Vulnerability{
			Name:      "CVE-YAY",
			Namespace: ns,
			Severity:  database.HighSeverity,
		},
		FixedInVersion: "2.1",
	}

	vulnFixed2 := database.VulnerabilityWithFixedIn{
		Vulnerability: database.Vulnerability{
			Name:      "CVE-YAY2",
			Namespace: ns,
			Severity:  database.HighSeverity,
		},
		FixedInVersion: "2.2",
	}

	if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) {
		t.FailNow()
	}

	r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f})
	// Everything below indexes r[0], so stop here if the lookup did not return
	// exactly one result.
	if !assert.Nil(t, err) || !assert.Len(t, r, 1) {
		t.FailNow()
	}

	anf := r[0]
	if !assert.True(t, anf.Valid) || !assert.Len(t, anf.AffectedBy, 2) {
		return
	}

	// The length check alone would accept the same CVE reported twice, so also
	// require each name to appear only once.
	seen := make(map[string]bool, len(anf.AffectedBy))
	for _, a := range anf.AffectedBy {
		assert.False(t, seen[a.Name], "vulnerability %q reported more than once", a.Name)
		seen[a.Name] = true

		switch a.Name {
		case vulnFixed1.Name:
			assert.Equal(t, vulnFixed1, a)
		case vulnFixed2.Name:
			assert.Equal(t, vulnFixed2, a)
		default:
			t.Fatalf("unexpected vulnerability %q affecting %s", a.Name, f.Name)
		}
	}
}

// TestFindVulnerabilities looks up vulnerabilities expected to exist in the test
// fixtures, plus one that does not exist. It checks that there is one result per
// requested ID and that valid results match the expected fixture content,
// including affected features. Invalid results must carry only the requested
// name. It also checks that duplicated IDs in one query each resolve to the same
// vulnerability.
func TestFindVulnerabilities(t *testing.T) {
	datastore, tx := openSessionForTest(t, "FindVulnerabilities", true)
	defer closeTest(t, datastore, tx)

	ns := database.Namespace{
		Name:          "debian:7",
		VersionFormat: "dpkg",
	}

	expectedExisting := []database.VulnerabilityWithAffected{
		{
			Vulnerability: database.Vulnerability{
				Namespace:   ns,
				Name:        "CVE-OPENSSL-1-DEB7",
				Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
				Link:        "http://google.com/#q=CVE-OPENSSL-1-DEB7",
				Severity:    database.HighSeverity,
			},
			Affected: []database.AffectedFeature{
				{
					FeatureName:     "openssl",
					AffectedVersion: "2.0",
					FixedInVersion:  "2.0",
					Namespace:       ns,
				},
				{
					FeatureName:     "libssl",
					AffectedVersion: "1.9-abc",
					FixedInVersion:  "1.9-abc",
					Namespace:       ns,
				},
			},
		},
		{
			Vulnerability: database.Vulnerability{
				Namespace:   ns,
				Name:        "CVE-NOPE",
				Description: "A vulnerability affecting nothing",
				Severity:    database.UnknownSeverity,
			},
		},
	}

	expectedExistingMap := make(map[database.VulnerabilityID]database.VulnerabilityWithAffected, len(expectedExisting))
	for _, v := range expectedExisting {
		expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v
	}

	// An invalid (not found) result is expected to contain only the requested name.
	nonexisting := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"},
	}

	queried := []database.VulnerabilityID{
		{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
		{Name: "CVE-NOPE", Namespace: "debian:7"},
		{Name: "CVE-NOT HERE"},
	}
	vulns, err := tx.FindVulnerabilities(queried)
	// Require one result per queried ID so an empty result cannot pass vacuously.
	if assert.Nil(t, err) && assert.Len(t, vulns, len(queried)) {
		for _, v := range vulns {
			if v.Valid {
				key := database.VulnerabilityID{
					Name:      v.Name,
					Namespace: v.Namespace.Name,
				}

				expected, ok := expectedExistingMap[key]
				if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) {
					assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
				}
			} else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) {
				t.FailNow()
			}
		}
	}

	// The same ID queried twice must produce two matching results.
	opensslID := database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}
	duplicated := []database.VulnerabilityID{opensslID, opensslID}
	r, err := tx.FindVulnerabilities(duplicated)
	if assert.Nil(t, err) && assert.Len(t, r, len(duplicated)) {
		expected := expectedExistingMap[opensslID]
		for _, v := range r {
			if assert.True(t, v.Valid) {
				assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
			}
		}
	}
}

// TestDeleteVulnerabilities checks that DeleteVulnerabilities succeeds on an
// empty list, returns an error for a zero-valued ID, and that vulnerabilities
// deleted in the session are then reported as not valid by FindVulnerabilities.
func TestDeleteVulnerabilities(t *testing.T) {
	datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true)
	defer closeTest(t, datastore, tx)

	// Nothing to delete: no error expected.
	assert.Nil(t, tx.DeleteVulnerabilities([]database.VulnerabilityID{}))
	// A single zero-valued ID is rejected.
	assert.NotNil(t, tx.DeleteVulnerabilities([]database.VulnerabilityID{{}}))

	targets := []database.VulnerabilityID{
		{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
		{Name: "CVE-NOPE", Namespace: "debian:7"},
	}
	assert.Nil(t, tx.DeleteVulnerabilities(targets))

	found, err := tx.FindVulnerabilities(targets)
	if !assert.Nil(t, err) {
		return
	}
	for i := range found {
		assert.False(t, found[i].Valid)
	}
}

// TestFindVulnerabilityIDs checks the internal ID lookups against the test
// fixtures. The latest deleted "CVE-DELETED" in debian:7 is expected to have
// id 3, and the not-deleted "CVE-NOPE" in debian:7 is expected to have id 2.
func TestFindVulnerabilityIDs(t *testing.T) {
	store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true)
	defer closeTest(t, store, tx)

	// Each assertion reports its own failure; the short-circuit only guards the
	// ids[0] access and the Int64 read of a possibly-null value.
	ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
	if assert.Nil(t, err) && assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) {
		assert.Equal(t, 3, int(ids[0].Int64))
	}

	ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
	if assert.Nil(t, err) && assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) {
		assert.Equal(t, 2, int(ids[0].Int64))
	}
}

// assertVulnerabilityWithAffectedEqual reports whether actual matches expected.
// The embedded Vulnerability is compared first; the affected features (order
// ignored) are compared only if it matches.
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
	if !assert.Equal(t, expected.Vulnerability, actual.Vulnerability) {
		return false
	}
	return assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}

// assertAffectedFeaturesEqual asserts that actual contains the same affected
// features as expected, ignoring order but respecting multiplicity. Every
// mismatch is recorded as a test failure on t, so callers that ignore the
// return value still see the failure. It returns true when the slices match.
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
	if !assert.Len(t, actual, len(expected)) {
		return false
	}

	// Count expected occurrences so duplicated entries are compared as a
	// multiset rather than collapsed into a set.
	remaining := make(map[database.AffectedFeature]int, len(expected))
	for _, f := range expected {
		remaining[f]++
	}

	ok := true
	for _, f := range actual {
		if remaining[f] == 0 {
			assert.Fail(t, "unexpected affected feature", "%+v is not expected or appears more often than expected", f)
			ok = false
			continue
		}
		remaining[f]--
	}
	return ok
}