// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package clair

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"github.com/coreos/clair/database"
	"github.com/stretchr/testify/assert"
)

// mockUpdaterDatastore is an in-memory mock datastore for the updater tests.
// The embedded database.MockDatastore supplies the function hooks (FctBegin
// is installed by newmockUpdaterDatastore); the maps hold the stored data:
//
//   - namespaces: keyed by namespace name.
//   - vulnerabilities: keyed by VulnerabilityID (vulnerability name plus
//     namespace name).
//   - vulnNotification: keyed by notification name.
//   - keyValues: arbitrary key/value pairs written via UpdateKeyValue.
type mockUpdaterDatastore struct {
	database.MockDatastore

	namespaces       map[string]database.Namespace
	vulnerabilities  map[database.VulnerabilityID]database.VulnerabilityWithAffected
	vulnNotification map[string]database.VulnerabilityNotification
	keyValues        map[string]string
}

// mockUpdaterSession is a session opened on a mockUpdaterDatastore. The
// embedded database.MockSession supplies the function hooks that
// newmockUpdaterDatastore installs.
//
//   - store: the datastore the session was opened on; Commit publishes the
//     session's data into it.
//   - copy: the private snapshot of the store's data that session operations
//     read and write.
//   - terminated: set once the session has been committed or rolled back.
type mockUpdaterSession struct {
	database.MockSession

	store      *mockUpdaterDatastore
	copy       mockUpdaterDatastore
	terminated bool
}

// copyUpdaterDatastore returns a snapshot of md's data maps that can be
// mutated without touching md. Map values are copied by assignment, except
// that each vulnerability gets its own (always non-nil) Affected slice. The
// embedded MockDatastore hooks are not carried over.
func copyUpdaterDatastore(md *mockUpdaterDatastore) mockUpdaterDatastore {
	var snapshot mockUpdaterDatastore

	snapshot.namespaces = make(map[string]database.Namespace, len(md.namespaces))
	for name, ns := range md.namespaces {
		snapshot.namespaces[name] = ns
	}

	snapshot.vulnerabilities = make(map[database.VulnerabilityID]database.VulnerabilityWithAffected, len(md.vulnerabilities))
	for id, vuln := range md.vulnerabilities {
		// Detach the Affected slice so edits to the snapshot cannot reach
		// the original's backing array.
		vuln.Affected = append(make([]database.AffectedFeature, 0, len(vuln.Affected)), vuln.Affected...)
		snapshot.vulnerabilities[id] = vuln
	}

	snapshot.vulnNotification = make(map[string]database.VulnerabilityNotification, len(md.vulnNotification))
	for name, notification := range md.vulnNotification {
		snapshot.vulnNotification[name] = notification
	}

	snapshot.keyValues = make(map[string]string, len(md.keyValues))
	for k, v := range md.keyValues {
		snapshot.keyValues[k] = v
	}

	return snapshot
}

// newmockUpdaterDatastore returns an empty in-memory mock datastore for the
// updater tests. Each Begin creates a session that works on a private
// snapshot of the store's data (see copyUpdaterDatastore):
//
//   - Commit replaces the store's maps with the session's snapshot.
//   - Rollback discards the snapshot.
//   - Once a session has been committed or rolled back, every further
//     operation on it returns errSessionDone.
func newmockUpdaterDatastore() *mockUpdaterDatastore {
	errSessionDone := errors.New("Session Done")
	md := &mockUpdaterDatastore{
		namespaces:       make(map[string]database.Namespace),
		vulnerabilities:  make(map[database.VulnerabilityID]database.VulnerabilityWithAffected),
		vulnNotification: make(map[string]database.VulnerabilityNotification),
		keyValues:        make(map[string]string),
	}

	md.FctBegin = func() (database.Session, error) {
		session := &mockUpdaterSession{
			store:      md,
			copy:       copyUpdaterDatastore(md),
			terminated: false,
		}

		// checkOpen rejects use of a finished session. Every operation must
		// call it: after Commit the snapshot's maps are shared with the
		// store, so late writes would leak into committed data. After
		// Rollback the snapshot is zeroed, so late writes would panic on
		// nil maps.
		checkOpen := func() error {
			if session.terminated {
				return errSessionDone
			}
			return nil
		}

		session.FctCommit = func() error {
			if err := checkOpen(); err != nil {
				return err
			}
			session.store.namespaces = session.copy.namespaces
			session.store.vulnerabilities = session.copy.vulnerabilities
			session.store.vulnNotification = session.copy.vulnNotification
			session.store.keyValues = session.copy.keyValues
			session.terminated = true
			return nil
		}

		session.FctRollback = func() error {
			if err := checkOpen(); err != nil {
				return err
			}
			session.terminated = true
			session.copy = mockUpdaterDatastore{}
			return nil
		}

		// PersistNamespaces inserts only namespaces not already present.
		session.FctPersistNamespaces = func(ns []database.Namespace) error {
			if err := checkOpen(); err != nil {
				return err
			}
			for _, n := range ns {
				if _, ok := session.copy.namespaces[n.Name]; !ok {
					session.copy.namespaces[n.Name] = n
				}
			}
			return nil
		}

		// FindVulnerabilities returns one entry per requested id, in order;
		// Valid is false for ids that are not stored.
		session.FctFindVulnerabilities = func(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
			if err := checkOpen(); err != nil {
				return nil, err
			}
			r := make([]database.NullableVulnerability, 0, len(ids))
			for _, id := range ids {
				vuln, ok := session.copy.vulnerabilities[id]
				r = append(r, database.NullableVulnerability{
					VulnerabilityWithAffected: vuln,
					Valid:                     ok,
				})
			}
			return r, nil
		}

		session.FctDeleteVulnerabilities = func(ids []database.VulnerabilityID) error {
			if err := checkOpen(); err != nil {
				return err
			}
			for _, id := range ids {
				delete(session.copy.vulnerabilities, id)
			}
			return nil
		}

		// InsertVulnerabilities fails on the first vulnerability whose
		// (name, namespace) id is already stored.
		session.FctInsertVulnerabilities = func(vulnerabilities []database.VulnerabilityWithAffected) error {
			if err := checkOpen(); err != nil {
				return err
			}
			for _, vuln := range vulnerabilities {
				id := database.VulnerabilityID{
					Name:      vuln.Name,
					Namespace: vuln.Namespace.Name,
				}
				if _, ok := session.copy.vulnerabilities[id]; ok {
					return errors.New("Vulnerability already exists")
				}
				session.copy.vulnerabilities[id] = vuln
			}
			return nil
		}

		session.FctUpdateKeyValue = func(key, value string) error {
			if err := checkOpen(); err != nil {
				return err
			}
			session.copy.keyValues[key] = value
			return nil
		}

		session.FctFindKeyValue = func(key string) (string, bool, error) {
			if err := checkOpen(); err != nil {
				return "", false, err
			}
			s, b := session.copy.keyValues[key]
			return s, b, nil
		}

		session.FctInsertVulnerabilityNotifications = func(notifications []database.VulnerabilityNotification) error {
			if err := checkOpen(); err != nil {
				return err
			}
			for _, noti := range notifications {
				session.copy.vulnNotification[noti.Name] = noti
			}
			return nil
		}

		return session, nil
	}
	return md
}

// TestDoVulnerabilitiesNamespacing checks that a vulnerability whose affected
// features span two namespaces is split into exactly one vulnerability per
// namespace. Each resulting vulnerability must carry only the affected
// features of its own namespace.
func TestDoVulnerabilitiesNamespacing(t *testing.T) {
	fv1 := database.AffectedFeature{
		FeatureType:     database.SourcePackage,
		Namespace:       database.Namespace{Name: "Namespace1"},
		FeatureName:     "Feature1",
		FixedInVersion:  "0.1",
		AffectedVersion: "0.1",
	}

	fv2 := database.AffectedFeature{
		FeatureType:     database.SourcePackage,
		Namespace:       database.Namespace{Name: "Namespace2"},
		FeatureName:     "Feature1",
		FixedInVersion:  "0.2",
		AffectedVersion: "0.2",
	}

	fv3 := database.AffectedFeature{
		FeatureType:     database.SourcePackage,
		Namespace:       database.Namespace{Name: "Namespace2"},
		FeatureName:     "Feature2",
		FixedInVersion:  "0.3",
		AffectedVersion: "0.3",
	}

	vulnerability := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name: "DoVulnerabilityNamespacing",
		},
		Affected: []database.AffectedFeature{fv1, fv2, fv3},
	}

	vulnerabilities := doVulnerabilitiesNamespacing([]database.VulnerabilityWithAffected{vulnerability})

	// Two distinct namespaces appear in the input. Without a count check, an
	// empty result would pass the loop below vacuously.
	assert.Len(t, vulnerabilities, 2)

	seen := map[string]bool{}
	for _, vulnerability := range vulnerabilities {
		name := vulnerability.Namespace.Name
		if seen[name] {
			t.Errorf("Namespace '%s' appears in more than one Vulnerability.", name)
		}
		seen[name] = true

		switch name {
		case fv1.Namespace.Name:
			assert.Len(t, vulnerability.Affected, 1)
			assert.Contains(t, vulnerability.Affected, fv1)
		case fv2.Namespace.Name:
			assert.Len(t, vulnerability.Affected, 2)
			assert.Contains(t, vulnerability.Affected, fv2)
			assert.Contains(t, vulnerability.Affected, fv3)
		default:
			t.Errorf("Should not have a Vulnerability with '%s' as its Namespace.", name)
			fmt.Printf("%#v\n", vulnerability)
		}
	}

	assert.True(t, seen[fv1.Namespace.Name], "missing Vulnerability for namespace %s", fv1.Namespace.Name)
	assert.True(t, seen[fv2.Namespace.Name], "missing Vulnerability for namespace %s", fv2.Namespace.Name)
}

// TestCreatVulnerabilityNotification runs updateVulnerabilities through a
// sequence of inserts and updates against a mock datastore, checking the
// reported change at each step. It then checks that
// createVulnerabilityNotifications stores a single notification holding the
// final old/new pair.
func TestCreatVulnerabilityNotification(t *testing.T) {
	vf1 := "VersionFormat1"
	ns1 := database.Namespace{
		Name:          "namespace 1",
		VersionFormat: vf1,
	}
	af1 := database.AffectedFeature{
		FeatureType: database.SourcePackage,
		Namespace:   ns1,
		FeatureName: "feature 1",
	}

	v1 := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "vulnerability 1",
			Namespace: ns1,
			Severity:  database.UnknownSeverity,
		},
	}

	// severity change
	v2 := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "vulnerability 1",
			Namespace: ns1,
			Severity:  database.LowSeverity,
		},
	}

	// affected versions change
	v3 := database.VulnerabilityWithAffected{
		Vulnerability: database.Vulnerability{
			Name:      "vulnerability 1",
			Namespace: ns1,
			Severity:  database.UnknownSeverity,
		},
		Affected: []database.AffectedFeature{af1},
	}

	ctx := context.TODO()
	datastore := newmockUpdaterDatastore()

	// No input: no changes.
	change, err := updateVulnerabilities(ctx, datastore, []database.VulnerabilityWithAffected{})
	assert.NoError(t, err)
	assert.Len(t, change, 0)

	// New vulnerability: one change without a previous version. Stop the test
	// instead of panicking on change[0] if the shape is wrong.
	change, err = updateVulnerabilities(ctx, datastore, []database.VulnerabilityWithAffected{v1})
	assert.NoError(t, err)
	if !assert.Len(t, change, 1) || !assert.NotNil(t, change[0].new) {
		t.FailNow()
	}
	assert.Nil(t, change[0].old)
	assertVulnerability(t, v1, *change[0].new)

	// Resubmitting an identical vulnerability is not a change.
	change, err = updateVulnerabilities(ctx, datastore, []database.VulnerabilityWithAffected{v1})
	assert.NoError(t, err)
	assert.Len(t, change, 0)

	// Severity change: v1 -> v2.
	change, err = updateVulnerabilities(ctx, datastore, []database.VulnerabilityWithAffected{v2})
	assert.NoError(t, err)
	if !assert.Len(t, change, 1) || !assert.NotNil(t, change[0].new) || !assert.NotNil(t, change[0].old) {
		t.FailNow()
	}
	assertVulnerability(t, v2, *change[0].new)
	assertVulnerability(t, v1, *change[0].old)

	// Affected features change: v2 -> v3.
	change, err = updateVulnerabilities(ctx, datastore, []database.VulnerabilityWithAffected{v3})
	assert.NoError(t, err)
	if !assert.Len(t, change, 1) || !assert.NotNil(t, change[0].new) || !assert.NotNil(t, change[0].old) {
		t.FailNow()
	}
	assertVulnerability(t, v3, *change[0].new)
	assertVulnerability(t, v2, *change[0].old)

	err = createVulnerabilityNotifications(datastore, change)
	assert.NoError(t, err)
	assert.Len(t, datastore.vulnNotification, 1)
	for _, noti := range datastore.vulnNotification {
		if !assert.NotNil(t, noti.New) || !assert.NotNil(t, noti.Old) {
			continue
		}
		assert.Equal(t, v3.Vulnerability, *noti.New)
		assert.Equal(t, v2.Vulnerability, *noti.Old)
	}
}

// assertVulnerability reports through t any difference between expected and
// actual, and returns true only when they match.
//
// All fields except Affected must be equal. Affected is compared as a
// multiset: order is ignored, but each feature must occur the same number of
// times in both slices.
func assertVulnerability(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
	t.Helper()

	expectedAF := expected.Affected
	actualAF := actual.Affected
	// Compare Affected separately so that its ordering does not matter.
	expected.Affected, actual.Affected = nil, nil

	ok := assert.Equal(t, expected, actual)
	if !assert.Len(t, actualAF, len(expectedAF)) {
		return false
	}

	// Count the expected occurrences, then consume one per actual feature.
	// Lengths are already equal, so the two multisets match exactly when no
	// actual feature finds its count exhausted.
	remaining := make(map[database.AffectedFeature]int, len(expectedAF))
	for _, af := range expectedAF {
		remaining[af]++
	}
	for _, af := range actualAF {
		if remaining[af] == 0 {
			t.Errorf("unexpected or duplicate affected feature: %+v", af)
			return false
		}
		remaining[af]--
	}
	return ok
}