// Copyright © 2017-2020 Aqua Security Software Ltd. <info@aquasec.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package check

import (
	"strings"
	"testing"
)

func TestCheck_Run(t *testing.T) {
	type TestCase struct {
		name     string
		check    Check
		Expected State
	}

	testCases := []TestCase{
		{name: "Manual check should WARN", check: Check{Type: MANUAL}, Expected: WARN},
		{name: "Skip check should INFO", check: Check{Type: "skip"}, Expected: INFO},
		{name: "Unscored check (with no type) should WARN on failure", check: Check{Scored: false}, Expected: WARN},
		{
			name: "Unscored check that pass should PASS",
			check: Check{
				Scored: false,
				Audit:  "echo hello",
				Tests: &tests{TestItems: []*testItem{{
					Flag: "hello",
					Set:  true,
				}}},
			},
			Expected: PASS,
		},

		{name: "Check with no tests should WARN", check: Check{Scored: true}, Expected: WARN},
		{name: "Scored check with empty tests should FAIL", check: Check{Scored: true, Tests: &tests{}}, Expected: FAIL},
		{
			name: "Scored check that doesn't pass should FAIL",
			check: Check{
				Scored: true,
				Audit:  "echo hello",
				Tests: &tests{TestItems: []*testItem{{
					Flag: "hello",
					Set:  false,
				}}},
			},
			Expected: FAIL,
		},
		{
			name: "Scored checks that pass should PASS",
			check: Check{
				Scored: true,
				Audit:  "echo hello",
				Tests: &tests{TestItems: []*testItem{{
					Flag: "hello",
					Set:  true,
				}}},
			},
			Expected: PASS,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.name, func(t *testing.T) {
			testCase.check.run()
			if testCase.check.State != testCase.Expected {
				t.Errorf("expected %s, actual %s", testCase.Expected, testCase.check.State)
			}
		})
	}
}

func TestCheckAuditEnv(t *testing.T) {
	passingCases := []*Check{
		controls.Groups[2].Checks[0],
		controls.Groups[2].Checks[2],
		controls.Groups[2].Checks[3],
		controls.Groups[2].Checks[4],
	}

	failingCases := []*Check{
		controls.Groups[2].Checks[1],
		controls.Groups[2].Checks[5],
		controls.Groups[2].Checks[6],
	}

	for _, c := range passingCases {
		t.Run(c.Text, func(t *testing.T) {
			c.run()
			if c.State != "PASS" {
				t.Errorf("Should PASS, got: %v", c.State)
			}
		})
	}

	for _, c := range failingCases {
		t.Run(c.Text, func(t *testing.T) {
			c.run()
			if c.State != "FAIL" {
				t.Errorf("Should FAIL, got: %v", c.State)
			}
		})
	}
}

func TestCheckAuditConfig(t *testing.T) {

	passingCases := []*Check{
		controls.Groups[1].Checks[0],
		controls.Groups[1].Checks[3],
		controls.Groups[1].Checks[5],
		controls.Groups[1].Checks[7],
		controls.Groups[1].Checks[9],
		controls.Groups[1].Checks[15],
	}

	failingCases := []*Check{
		controls.Groups[1].Checks[1],
		controls.Groups[1].Checks[2],
		controls.Groups[1].Checks[4],
		controls.Groups[1].Checks[6],
		controls.Groups[1].Checks[8],
		controls.Groups[1].Checks[10],
		controls.Groups[1].Checks[11],
		controls.Groups[1].Checks[12],
		controls.Groups[1].Checks[13],
		controls.Groups[1].Checks[14],
		controls.Groups[1].Checks[16],
	}

	for _, c := range passingCases {
		t.Run(c.Text, func(t *testing.T) {
			c.run()
			if c.State != "PASS" {
				t.Errorf("Should PASS, got: %v", c.State)
			}
		})
	}

	for _, c := range failingCases {
		t.Run(c.Text, func(t *testing.T) {
			c.run()
			if c.State != "FAIL" {
				t.Errorf("Should FAIL, got: %v", c.State)
			}
		})
	}
}

// Test_runAudit exercises runAudit with a single-line script, a multi-line
// shell script, and a command that does not exist.
//
// A non-empty errMsg means an error is required: its message must contain
// errMsg and the returned output must contain output. An empty errMsg means
// no error is allowed and the output must match exactly.
func Test_runAudit(t *testing.T) {
	type args struct {
		audit string
	}
	// Named testCases (not "tests") to avoid shadowing the package's tests type.
	testCases := []struct {
		name   string
		args   args
		errMsg string
		output string
	}{
		{
			name: "run success",
			args: args{
				audit: "echo 'hello world'",
			},
			errMsg: "",
			output: "hello world\n",
		},
		{
			name: "run multiple lines script",
			args: args{
				audit: `
hello() {
  echo "hello world"
}

hello
`,
			},
			errMsg: "",
			output: "hello world\n",
		},
		{
			name: "run failed",
			args: args{
				audit: "unknown_command",
			},
			errMsg: "failed to run: \"unknown_command\", output: \"/bin/sh: ",
			output: "not found\n",
		},
	}
	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			output, err := runAudit(tt.args.audit)
			wantErr := tt.errMsg != ""

			// Previously an unexpected error on a success case could go
			// unnoticed; require the error's presence to match the case.
			switch {
			case err != nil && !wantErr:
				t.Fatalf("unexpected error: %v (output %q)", err, output)
			case err == nil && wantErr:
				t.Fatalf("expected error containing %q, got nil (output %q)", tt.errMsg, output)
			}

			if wantErr {
				if msg := err.Error(); !strings.Contains(msg, tt.errMsg) {
					t.Errorf("errMsg = %q, want it to contain %q", msg, tt.errMsg)
				}
				// Shell error text varies by platform, so only require a substring.
				if !strings.Contains(output, tt.output) {
					t.Errorf("output = %q, want it to contain %q", output, tt.output)
				}
				return
			}

			if output != tt.output {
				t.Errorf("output = %q, want %q", output, tt.output)
			}
		})
	}
}