diff --git a/cmd/common.go b/cmd/common.go index fff30c7..eb4c5bb 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -205,24 +205,29 @@ func prettyPrint(r *check.Controls, summary check.Summary) { // Print remediations. if !noRemediations { + var remediationOutput strings.Builder if summary.Fail > 0 || summary.Warn > 0 { - colors[check.WARN].Printf("== Remediations %s ==\n", r.Type) for _, g := range r.Groups { for _, c := range g.Checks { - if c.State == check.FAIL { - fmt.Printf("%s %s\n", c.ID, c.Remediation) + if c.State == check.FAIL && printStatus(check.FAIL) { + remediationOutput.WriteString(fmt.Sprintf("%s %s\n", c.ID, c.Remediation)) } - if c.State == check.WARN { + if c.State == check.WARN && printStatus(check.WARN) { // Print the error if test failed due to problem with the audit command if c.Reason != "" && c.Type != "manual" { - fmt.Printf("%s audit test did not run: %s\n", c.ID, c.Reason) + remediationOutput.WriteString(fmt.Sprintf("%s audit test did not run: %s\n", c.ID, c.Reason)) } else { - fmt.Printf("%s %s\n", c.ID, c.Remediation) + remediationOutput.WriteString(fmt.Sprintf("%s %s\n", c.ID, c.Remediation)) } } } } - fmt.Println() + output := remediationOutput.String() + if len(output) > 0 { + remediationOutput.WriteString("\n") + fmt.Printf(colors[check.WARN].Sprintf("== Remediations %s ==\n", r.Type)) + fmt.Printf(remediationOutput.String()) + } } } diff --git a/cmd/common_test.go b/cmd/common_test.go index c9b4a4f..459b5c0 100644 --- a/cmd/common_test.go +++ b/cmd/common_test.go @@ -756,27 +756,35 @@ func TestWriteStdoutOutputStatusList(t *testing.T) { statusList string notContains []string + contains []string } testCases := []testCase{ { name: "statusList PASS", statusList: "PASS", - notContains: []string{"INFO", "WARN", "FAIL"}, + notContains: []string{"INFO", "WARN", "FAIL", "== Remediations controlplane =="}, }, { name: "statusList PASS,INFO", statusList: "PASS,INFO", - notContains: []string{"WARN", "FAIL"}, + notContains: []string{"WARN", "FAIL", "== Remediations controlplane =="}, + }, + { + name: "statusList WARN", + statusList: "WARN", + notContains: []string{"INFO", "FAIL", "PASS"}, + contains: []string{"== Remediations controlplane =="}, }, { name: "statusList FAIL", statusList: "FAIL", - notContains: []string{"INFO", "WARN", "PASS"}, + notContains: []string{"INFO", "WARN", "PASS", "== Remediations controlplane =="}, }, { name: "statusList empty", statusList: "", notContains: nil, + contains: []string{"== Remediations controlplane =="}, }, } @@ -801,6 +809,10 @@ func TestWriteStdoutOutputStatusList(t *testing.T) { for _, n := range tt.notContains { assert.NotContains(t, string(out), n) } + + for _, c := range tt.contains { + assert.Contains(t, string(out), c) + } } }