diff --git a/cmd/util.go b/cmd/util.go index 87b13f0..0bc40be 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -10,6 +10,7 @@ import ( "github.com/aquasecurity/kube-bench/check" "github.com/fatih/color" "github.com/golang/glog" + "github.com/spf13/viper" ) var ( @@ -83,6 +84,37 @@ func ps(proc string) string { return string(out) } +// getBinaries finds which of the set of candidate executables are running +func getBinaries(v *viper.Viper) map[string]string { + binmap := make(map[string]string) + + for _, exeType := range v.AllKeys() { + bin, err := findExecutable(v.GetStringSlice(exeType)) + if err != nil { + exitWithError(fmt.Errorf("looking for %s executable but none of the candidates are running", exeType)) + } + + binmap[exeType] = bin + } + return binmap +} + +// getConfigFiles finds which of the set of candidate config files exist +func getConfigFiles(v *viper.Viper) map[string]string { + confmap := make(map[string]string) + + for _, confType := range v.AllKeys() { + conf := findConfigFile(v.GetStringSlice(confType)) + if conf == "" { + printlnWarn(fmt.Sprintf("Missing kubernetes config file for %s", confType)) + } else { + confmap[confType] = conf + } + } + + return confmap +} + // verifyBin checks that the binary specified is running func verifyBin(bin string) bool { @@ -110,6 +142,21 @@ func verifyBin(bin string) bool { return false } +// fundConfigFile looks through a list of possible config files and finds the first one that exists +func findConfigFile(candidates []string) string { + for _, c := range candidates { + _, err := statFunc(c) + if err == nil { + return c + } + if !os.IsNotExist(err) { + exitWithError(fmt.Errorf("error looking for file %s: %v", c, err)) + } + } + + return "" +} + // findExecutable looks through a list of possible executable names and finds the first one that's running func findExecutable(candidates []string) (string, error) { for _, c := range candidates { diff --git a/cmd/util_test.go b/cmd/util_test.go index 659dd2e..d5192a6 100644 --- a/cmd/util_test.go +++ b/cmd/util_test.go @@ -15,9 +15,13 @@ package cmd import ( + "os" + "reflect" "regexp" "strconv" "testing" + + "github.com/spf13/viper" ) func TestCheckVersion(t *testing.T) { @@ -78,10 +82,19 @@ func TestVersionMatch(t *testing.T) { } var g string +var e []error +var eIndex int func fakeps(proc string) string { return g } + +func fakestat(file string) (os.FileInfo, error) { + err := e[eIndex] + eIndex++ + return nil, err +} + func TestVerifyBin(t *testing.T) { cases := []struct { proc string @@ -145,6 +158,41 @@ func TestFindExecutable(t *testing.T) { } } +func TestGetBinaries(t *testing.T) { + cases := []struct { + config map[string]interface{} + psOut string + exp map[string]string + }{ + { + config: map[string]interface{}{"apiserver": []string{"apiserver", "kube-apiserver"}}, + psOut: "kube-apiserver", + exp: map[string]string{"apiserver": "kube-apiserver"}, + }, + { + config: map[string]interface{}{"apiserver": []string{"apiserver", "kube-apiserver"}, "thing": []string{"something else", "thing"}}, + psOut: "kube-apiserver thing", + exp: map[string]string{"apiserver": "kube-apiserver", "thing": "thing"}, + }, + } + + v := viper.New() + psFunc = fakeps + + for id, c := range cases { + t.Run(strconv.Itoa(id), func(t *testing.T) { + g = c.psOut + for k, val := range c.config { + v.Set(k, val) + } + m := getBinaries(v) + if !reflect.DeepEqual(m, c.exp) { + t.Fatalf("Got %v\nExpected %v", m, c.exp) + } + }) + } +} + func TestMultiWordReplace(t *testing.T) { cases := []struct { input string @@ -166,3 +214,64 @@ func TestMultiWordReplace(t *testing.T) { }) } } + +func TestFindConfigFile(t *testing.T) { + cases := []struct { + input []string + statResults []error + exp string + }{ + {input: []string{"myfile"}, statResults: []error{nil}, exp: "myfile"}, + {input: []string{"thisfile", "thatfile"}, statResults: []error{os.ErrNotExist, nil}, exp: "thatfile"}, + {input: []string{"thisfile", "thatfile"}, statResults: []error{os.ErrNotExist, os.ErrNotExist}, exp: ""}, + } + + statFunc = fakestat + for id, c := range cases { + t.Run(strconv.Itoa(id), func(t *testing.T) { + e = c.statResults + eIndex = 0 + conf := findConfigFile(c.input) + if conf != c.exp { + t.Fatalf("Got %s expected %s", conf, c.exp) + } + }) + } +} + +func TestGetConfigFiles(t *testing.T) { + cases := []struct { + config map[string]interface{} + exp map[string]string + statResults []error + }{ + { + config: map[string]interface{}{"apiserver": []string{"apiserver", "kube-apiserver"}}, + statResults: []error{os.ErrNotExist, nil}, + exp: map[string]string{"apiserver": "kube-apiserver"}, + }, + { + config: map[string]interface{}{"apiserver": []string{"apiserver", "kube-apiserver"}, "thing": []string{"/my/file/thing"}}, + statResults: []error{os.ErrNotExist, nil, nil}, + exp: map[string]string{"apiserver": "kube-apiserver", "thing": "/my/file/thing"}, + }, + } + + v := viper.New() + statFunc = fakestat + + for id, c := range cases { + t.Run(strconv.Itoa(id), func(t *testing.T) { + for k, val := range c.config { + v.Set(k, val) + } + e = c.statResults + eIndex = 0 + + m := getConfigFiles(v) + if !reflect.DeepEqual(m, c.exp) { + t.Fatalf("Got %v\nExpected %v", m, c.exp) + } + }) + } +}