diff --git a/.gitignore b/.gitignore index 4bfdd88..1ce4a32 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ vendor dist .vscode/ hack/kind.test.yaml + +.idea/ \ No newline at end of file diff --git a/Gopkg.lock b/Gopkg.lock index e74b52e..601bce2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -1,6 +1,14 @@ # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. +[[projects]] + digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" + name = "github.com/davecgh/go-spew" + packages = ["spew"] + pruneopts = "UT" + revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" + version = "v1.1.1" + [[projects]] digest = "1:938a2672d6ebbb7f7bc63eee3e4b9464c16ffcf77ec8913d3edbf32b4e3984dd" name = "github.com/fatih/color" @@ -113,6 +121,14 @@ pruneopts = "UT" revision = "0131db6d737cfbbfb678f8b7d92e55e27ce46224" +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "UT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + [[projects]] digest = "1:1fccaaeae58b2a2f1af4dbf7eee92ff14f222e161d143bfd20082ef664f91216" name = "github.com/spf13/afero" @@ -161,6 +177,25 @@ revision = "25b30aa063fc18e48662b86996252eabdcf2f0c7" version = "v1.0.0" +[[projects]] + digest = "1:ac83cf90d08b63ad5f7e020ef480d319ae890c208f8524622a2f3136e2686b02" + name = "github.com/stretchr/objx" + packages = ["."] + pruneopts = "UT" + revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" + version = "v0.1.1" + +[[projects]] + digest = "1:0bcc464dabcfad5393daf87c3f8142911d0f6c52569b837e91a1c15e890265f3" + name = "github.com/stretchr/testify" + packages = [ + "assert", + "mock", + ] + pruneopts = "UT" + revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" + version = "v1.3.0" + [[projects]] digest = "1:c9c0ba9ea00233c41b91e441cfd490f34b129bbfebcb1858979623bd8de07f72" name = "golang.org/x/sys" @@ -210,7 +245,10 @@ "github.com/jinzhu/gorm/dialects/postgres", "github.com/spf13/cobra", "github.com/spf13/viper", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", "gopkg.in/yaml.v2", + "k8s.io/client-go/util/jsonpath", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index c397575..4c062d7 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -18,6 +18,10 @@ name = "github.com/spf13/viper" version = "1.0.0" +[[constraint]] + name = "github.com/stretchr/testify" + version = "1.3.0" + [prune] go-tests = true unused-packages = true diff --git a/check/check.go b/check/check.go index 4ace74b..6ebd253 100644 --- a/check/check.go +++ b/check/check.go @@ -36,11 +36,11 @@ const ( // PASS check passed. PASS State = "PASS" // FAIL check failed. - FAIL = "FAIL" + FAIL State = "FAIL" // WARN could not carry out check. - WARN = "WARN" + WARN State = "WARN" // INFO informational message - INFO = "INFO" + INFO State = "INFO" // MASTER a master node MASTER NodeType = "master" @@ -74,20 +74,37 @@ type Check struct { Scored bool `json:"scored"` } +// Runner wraps the basic Run method. +type Runner interface { + // Run runs a given check and returns the execution state. + Run(c *Check) State +} + +// NewRunner constructs a default Runner. +func NewRunner() Runner { + return &defaultRunner{} +} + +type defaultRunner struct{} + +func (r *defaultRunner) Run(c *Check) State { + return c.run() +} + // Run executes the audit commands specified in a check and outputs // the results. -func (c *Check) Run() { +func (c *Check) run() State { // If check type is skip, force result to INFO if c.Type == "skip" { c.State = INFO - return + return c.State } // If check type is manual or the check is not scored, force result to WARN if c.Type == "manual" || !c.Scored { c.State = WARN - return + return c.State } var out bytes.Buffer @@ -97,7 +114,7 @@ func (c *Check) Run() { for _, cmd := range c.Commands { if !isShellCommand(cmd.Path) { c.State = WARN - return + return c.State } } @@ -106,7 +123,7 @@ func (c *Check) Run() { if n == 0 { // Likely a warning message. c.State = WARN - return + return c.State } // Each command runs, @@ -188,6 +205,7 @@ func (c *Check) Run() { if errmsgs != "" { glog.V(2).Info(errmsgs) } + return c.State } // textToCommand transforms an input text representation of commands to be diff --git a/check/check_test.go b/check/check_test.go index ab74656..27c3c64 100644 --- a/check/check_test.go +++ b/check/check_test.go @@ -21,7 +21,7 @@ func TestCheck_Run(t *testing.T) { for _, testCase := range testCases { - testCase.check.Run() + testCase.check.run() if testCase.check.State != testCase.Expected { t.Errorf("test failed, expected %s, actual %s\n", testCase.Expected, testCase.check.State) diff --git a/check/controls.go b/check/controls.go index f6d4ab9..84635a1 100644 --- a/check/controls.go +++ b/check/controls.go @@ -17,8 +17,7 @@ package check import ( "encoding/json" "fmt" - - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) // Controls holds all controls to check for master nodes. @@ -50,6 +49,8 @@ type Summary struct { Info int `json:"total_info"` } +type Predicate func(group *Group, check *Check) bool + // NewControls instantiates a new master Controls object. func NewControls(t NodeType, in []byte) (*Controls, error) { c := new(Controls) @@ -73,76 +74,44 @@ func NewControls(t NodeType, in []byte) (*Controls, error) { return c, nil } -// RunGroup runs all checks in a group. -func (controls *Controls) RunGroup(gids ...string) Summary { - g := []*Group{} +// RunChecks runs the checks with the given Runner. Only checks for which the filter Predicate returns `true` will run. +func (controls *Controls) RunChecks(runner Runner, filter Predicate) Summary { + var g []*Group + m := make(map[string]*Group) controls.Summary.Pass, controls.Summary.Fail, controls.Summary.Warn, controls.Info = 0, 0, 0, 0 - // If no groupid is passed run all group checks. - if len(gids) == 0 { - gids = controls.getAllGroupIDs() - } - for _, group := range controls.Groups { + for _, check := range group.Checks { - for _, gid := range gids { - if gid == group.ID { - for _, check := range group.Checks { - check.Run() - check.TestInfo = append(check.TestInfo, check.Remediation) - summarize(controls, check) - summarizeGroup(group, check) - } - - g = append(g, group) + if !filter(group, check) { + continue } - } - } - controls.Groups = g - return controls.Summary -} + state := runner.Run(check) + check.TestInfo = append(check.TestInfo, check.Remediation) -// RunChecks runs the checks with the supplied IDs. -func (controls *Controls) RunChecks(ids ...string) Summary { - g := []*Group{} - m := make(map[string]*Group) - controls.Summary.Pass, controls.Summary.Fail, controls.Summary.Warn, controls.Info = 0, 0, 0, 0 + // Check if we have already added this checks group. + if v, ok := m[group.ID]; !ok { + // Create a group with same info + w := &Group{ + ID: group.ID, + Text: group.Text, + Checks: []*Check{}, + } - // If no groupid is passed run all group checks. - if len(ids) == 0 { - ids = controls.getAllCheckIDs() - } + // Add this check to the new group + w.Checks = append(w.Checks, check) + summarizeGroup(w, state) - for _, group := range controls.Groups { - for _, check := range group.Checks { - for _, id := range ids { - if id == check.ID { - check.Run() - check.TestInfo = append(check.TestInfo, check.Remediation) - summarize(controls, check) - - // Check if we have already added this checks group. - if v, ok := m[group.ID]; !ok { - // Create a group with same info - w := &Group{ - ID: group.ID, - Text: group.Text, - Checks: []*Check{}, - } - - // Add this check to the new group - w.Checks = append(w.Checks, check) - - // Add to groups we have visited. - m[w.ID] = w - g = append(g, w) - } else { - v.Checks = append(v.Checks, check) - } - - } + // Add to groups we have visited. + m[w.ID] = w + g = append(g, w) + } else { + v.Checks = append(v.Checks, check) + summarizeGroup(v, state) } + + summarize(controls, state) } } @@ -155,29 +124,8 @@ func (controls *Controls) JSON() ([]byte, error) { return json.Marshal(controls) } -func (controls *Controls) getAllGroupIDs() []string { - var ids []string - - for _, group := range controls.Groups { - ids = append(ids, group.ID) - } - return ids -} - -func (controls *Controls) getAllCheckIDs() []string { - var ids []string - - for _, group := range controls.Groups { - for _, check := range group.Checks { - ids = append(ids, check.ID) - } - } - return ids - -} - -func summarize(controls *Controls, check *Check) { - switch check.State { +func summarize(controls *Controls, state State) { + switch state { case PASS: controls.Summary.Pass++ case FAIL: @@ -189,8 +137,8 @@ func summarize(controls *Controls, check *Check) { } } -func summarizeGroup(group *Group, check *Check) { - switch check.State { +func summarizeGroup(group *Group, state State) { + switch state { case PASS: group.Pass++ case FAIL: diff --git a/check/controls_test.go b/check/controls_test.go index 17c62e5..d0d480c 100644 --- a/check/controls_test.go +++ b/check/controls_test.go @@ -6,11 +6,22 @@ import ( "path/filepath" "testing" - yaml "gopkg.in/yaml.v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "gopkg.in/yaml.v2" ) const cfgDir = "../cfg/" +type mockRunner struct { + mock.Mock +} + +func (m *mockRunner) Run(c *Check) State { + args := m.Called(c) + return args.Get(0).(State) +} + // validate that the files we're shipping are valid YAML func TestYamlFiles(t *testing.T) { err := filepath.Walk(cfgDir, func(path string, info os.FileInfo, err error) error { @@ -38,3 +49,74 @@ func TestYamlFiles(t *testing.T) { t.Fatalf("failure walking cfg dir: %v\n", err) } } + +func TestNewControls(t *testing.T) { + + t.Run("Should return error when node type is not specified", func(t *testing.T) { + // given + in := []byte(` +--- +controls: +type: # not specified +groups: +`) + // when + _, err := NewControls(MASTER, in) + // then + assert.EqualError(t, err, "non-master controls file specified") + }) + + t.Run("Should return error when input YAML is invalid", func(t *testing.T) { + // given + in := []byte("BOOM") + // when + _, err := NewControls(MASTER, in) + // then + assert.EqualError(t, err, "failed to unmarshal YAML: yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `BOOM` into check.Controls") + }) + +} + +func TestControls_RunChecks(t *testing.T) { + + t.Run("Should run all checks", func(t *testing.T) { + // given + runner := new(mockRunner) + // and + in := []byte(` +--- +type: "master" +groups: +- id: G1 + checks: + - id: G1/C1 +- id: G2 + checks: + - id: G2/C1 +`) + // and + controls, _ := NewControls(MASTER, in) + // and + runner.On("Run", controls.Groups[0].Checks[0]).Return(PASS) + runner.On("Run", controls.Groups[1].Checks[0]).Return(FAIL) + // and + var runAll Predicate = func(group *Group, c *Check) bool { + return true + } + // when + controls.RunChecks(runner, runAll) + // then + assert.Equal(t, 2, len(controls.Groups)) + // and + assert.Equal(t, "G1", controls.Groups[0].ID) + assert.Equal(t, "G1/C1", controls.Groups[0].Checks[0].ID) + // and + assert.Equal(t, "G2", controls.Groups[1].ID) + assert.Equal(t, "G2/C1", controls.Groups[1].Checks[0].ID) + // and + // TODO We can assert that group and controls summaries are updated. + // and + runner.AssertExpectations(t) + }) + +} diff --git a/cmd/common.go b/cmd/common.go index ed6e9b5..cbc8f91 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -29,6 +29,51 @@ var ( errmsgs string ) +// NewRunFilter constructs a Predicate based on FilterOptions which determines whether tested Checks should be run or not. +func NewRunFilter(opts FilterOpts) check.Predicate { + + if opts.CheckList != "" && opts.GroupList != "" { + exitWithError(fmt.Errorf("group option and check option can't be used together")) + } + + var groupIDs map[string]bool + if opts.GroupList != "" { + groupIDs = cleanIDs(opts.GroupList) + } + + var checkIDs map[string]bool + if opts.CheckList != "" { + checkIDs = cleanIDs(opts.CheckList) + } + + return func(g *check.Group, c *check.Check) bool { + if len(groupIDs) > 0 { + _, ok := groupIDs[g.ID] + if !ok { + return false + } + } + + if len(checkIDs) > 0 { + _, ok := checkIDs[c.ID] + if !ok { + return false + } + } + + if opts.Scored && opts.Unscored { + return true + } + if opts.Scored { + return c.Scored + } + if opts.Unscored { + return !c.Scored + } + return true + } +} + func runChecks(nodetype check.NodeType) { var summary check.Summary @@ -40,7 +85,7 @@ func runChecks(nodetype check.NodeType) { glog.V(1).Info(fmt.Sprintf("Using benchmark file: %s\n", def)) - // Get the set of exectuables and config files we care about on this type of node. + // Get the set of executables and config files we care about on this type of node. typeConf := viper.Sub(string(nodetype)) binmap, err := getBinaries(typeConf) @@ -65,17 +110,10 @@ func runChecks(nodetype check.NodeType) { exitWithError(fmt.Errorf("error setting up %s controls: %v", nodetype, err)) } - if groupList != "" && checkList == "" { - ids := cleanIDs(groupList) - summary = controls.RunGroup(ids...) - } else if checkList != "" && groupList == "" { - ids := cleanIDs(checkList) - summary = controls.RunChecks(ids...) - } else if checkList != "" && groupList != "" { - exitWithError(fmt.Errorf("group option and check option can't be used together")) - } else { - summary = controls.RunGroup() - } + runner := check.NewRunner() + filter := NewRunFilter(filterOpts) + + summary = controls.RunChecks(runner, filter) // if we successfully ran some tests and it's json format, ignore the warnings if (summary.Fail > 0 || summary.Warn > 0 || summary.Pass > 0 || summary.Info > 0) && jsonFmt { diff --git a/cmd/common_test.go b/cmd/common_test.go new file mode 100644 index 0000000..40fc906 --- /dev/null +++ b/cmd/common_test.go @@ -0,0 +1,103 @@ +// Copyright © 2017 Aqua Security Software Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "github.com/aquasecurity/kube-bench/check" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewRunFilter(t *testing.T) { + + type TestCase struct { + Name string + FilterOpts FilterOpts + Group *check.Group + Check *check.Check + + Expected bool + } + + testCases := []TestCase{ + { + Name: "Should return true when scored flag is enabled and check is scored", + FilterOpts: FilterOpts{Scored: true, Unscored: false}, + Group: &check.Group{}, + Check: &check.Check{Scored: true}, + Expected: true, + }, + { + Name: "Should return false when scored flag is enabled and check is not scored", + FilterOpts: FilterOpts{Scored: true, Unscored: false}, + Group: &check.Group{}, + Check: &check.Check{Scored: false}, + Expected: false, + }, + + { + Name: "Should return true when unscored flag is enabled and check is not scored", + FilterOpts: FilterOpts{Scored: false, Unscored: true}, + Group: &check.Group{}, + Check: &check.Check{Scored: false}, + Expected: true, + }, + { + Name: "Should return false when unscored flag is enabled and check is scored", + FilterOpts: FilterOpts{Scored: false, Unscored: true}, + Group: &check.Group{}, + Check: &check.Check{Scored: true}, + Expected: false, + }, + + { + Name: "Should return true when group flag contains group's ID", + FilterOpts: FilterOpts{GroupList: "G1,G2,G3"}, + Group: &check.Group{ID: "G2"}, + Check: &check.Check{}, + Expected: true, + }, + { + Name: "Should return false when group flag doesn't contain group's ID", + FilterOpts: FilterOpts{GroupList: "G1,G3"}, + Group: &check.Group{ID: "G2"}, + Check: &check.Check{}, + Expected: false, + }, + + { + Name: "Should return true when check flag contains check's ID", + FilterOpts: FilterOpts{CheckList: "C1,C2,C3"}, + Group: &check.Group{}, + Check: &check.Check{ID: "C2"}, + Expected: true, + }, + { + Name: "Should return false when check flag doesn't contain check's ID", + FilterOpts: FilterOpts{CheckList: "C1,C3"}, + Group: &check.Group{}, + Check: &check.Check{ID: "C2"}, + Expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + filter := NewRunFilter(testCase.FilterOpts) + assert.Equal(t, testCase.Expected, filter(testCase.Group, testCase.Check)) + }) + } + +} diff --git a/cmd/root.go b/cmd/root.go index eab044d..1a3e844 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,6 +25,13 @@ import ( "github.com/spf13/viper" ) +type FilterOpts struct { + CheckList string + GroupList string + Scored bool + Unscored bool +} + var ( envVarsPrefix = "KUBE_BENCH" defaultKubeVersion = "1.6" @@ -33,14 +40,13 @@ var ( cfgDir string jsonFmt bool pgSQL bool - checkList string - groupList string masterFile = "master.yaml" nodeFile = "node.yaml" federatedFile string noResults bool noSummary bool noRemediations bool + filterOpts FilterOpts ) // RootCmd represents the base command when called without any subcommands @@ -79,16 +85,18 @@ func init() { RootCmd.PersistentFlags().BoolVar(&noRemediations, "noremediations", false, "Disable printing of remediations section") RootCmd.PersistentFlags().BoolVar(&jsonFmt, "json", false, "Prints the results as JSON") RootCmd.PersistentFlags().BoolVar(&pgSQL, "pgsql", false, "Save the results to PostgreSQL") + RootCmd.PersistentFlags().BoolVar(&filterOpts.Scored, "scored", false, "Run only scored CIS checks") + RootCmd.PersistentFlags().BoolVar(&filterOpts.Unscored, "unscored", false, "Run only unscored CIS checks") RootCmd.PersistentFlags().StringVarP( - &checkList, + &filterOpts.CheckList, "check", "c", "", `A comma-delimited list of checks to run as specified in CIS document. Example --check="1.1.1,1.1.2"`, ) RootCmd.PersistentFlags().StringVarP( - &groupList, + &filterOpts.GroupList, "group", "g", "", diff --git a/cmd/util.go b/cmd/util.go index 87cda65..29b7d69 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -50,15 +50,18 @@ func continueWithError(err error, msg string) string { return "" } -func cleanIDs(list string) []string { +func cleanIDs(list string) map[string]bool { list = strings.Trim(list, ",") ids := strings.Split(list, ",") + set := make(map[string]bool) + for _, id := range ids { id = strings.Trim(id, " ") + set[id] = true } - return ids + return set } // ps execs out to the ps command; it's separated into a function so we can write tests