Skip to content

Commit

Permalink
[VAULT-23467] Validate audit filter option against filter selectors r…
Browse files Browse the repository at this point in the history
…eferencing unsupported fields (#25012)

* Validate audit filter option against filter selectors referencing unsupported fields

* Test updates due to filter validation

* Test all properties of the log input bexpr datum struct in filters

* Remove redundant cloning of the client in external tests for audit filtering

* TestAuditFilteringFilterForUnsupportedField now also tests the same behaviour with skip_test option set to true

* Add filter validation test cases to unit tests for audit backends

---------

Co-authored-by: Peter Wilson <[email protected]>
  • Loading branch information
kubawi and Peter Wilson authored Jan 23, 2024
1 parent 349a859 commit a1295a5
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 26 deletions.
10 changes: 10 additions & 0 deletions audit/entry_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/hashicorp/go-bexpr"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/logical"
)

var _ eventlogger.Node = (*EntryFilter)(nil)
Expand All @@ -31,6 +32,15 @@ func NewEntryFilter(filter string) (*EntryFilter, error) {
return nil, fmt.Errorf("%s: cannot create new audit filter: %w", op, err)
}

// Validate the filter by attempting to evaluate it with an empty input.
// This prevents users providing a filter with a field that would error during
// matching, and block all auditable requests to Vault.
li := logical.LogInputBexpr{}
_, err = eval.Evaluate(li)
if err != nil {
return nil, fmt.Errorf("%s: filter references an unsupported field: %s", op, filter)
}

return &EntryFilter{evaluator: eval}, nil
}

Expand Down
33 changes: 27 additions & 6 deletions audit/entry_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,29 @@ func TestEntryFilter_NewEntryFilter(t *testing.T) {
IsErrorExpected: true,
ExpectedErrorMessage: "audit.NewEntryFilter: cannot create new audit filter",
},
"good-filter": {
Filter: "foo == bar",
"unsupported-field-filter": {
Filter: "foo == bar",
IsErrorExpected: true,
ExpectedErrorMessage: "audit.NewEntryFilter: filter references an unsupported field: foo == bar",
},
"good-filter-operation": {
Filter: "operation == create",
IsErrorExpected: false,
},
"good-filter-mount_type": {
Filter: "mount_type == kv",
IsErrorExpected: false,
},
"good-filter-mount_point": {
Filter: "mount_point == \"/auth/userpass\"",
IsErrorExpected: false,
},
"good-filter-namespace": {
Filter: "namespace == juan",
IsErrorExpected: false,
},
"good-filter-path": {
Filter: "path == foo",
IsErrorExpected: false,
},
}
Expand Down Expand Up @@ -92,7 +113,7 @@ func TestEntryFilter_Process_ContextDone(t *testing.T) {
// Explicitly cancel the context
cancel()

l, err := NewEntryFilter("foo == bar")
l, err := NewEntryFilter("operation == foo")
require.NoError(t, err)

// Fake audit event
Expand Down Expand Up @@ -121,7 +142,7 @@ func TestEntryFilter_Process_ContextDone(t *testing.T) {
func TestEntryFilter_Process_NilEvent(t *testing.T) {
t.Parallel()

l, err := NewEntryFilter("foo == bar")
l, err := NewEntryFilter("operation == foo")
require.NoError(t, err)
e, err := l.Process(context.Background(), nil)
require.Error(t, err)
Expand All @@ -137,7 +158,7 @@ func TestEntryFilter_Process_NilEvent(t *testing.T) {
func TestEntryFilter_Process_BadPayload(t *testing.T) {
t.Parallel()

l, err := NewEntryFilter("foo == bar")
l, err := NewEntryFilter("operation == foo")
require.NoError(t, err)

e := &eventlogger.Event{
Expand All @@ -160,7 +181,7 @@ func TestEntryFilter_Process_BadPayload(t *testing.T) {
func TestEntryFilter_Process_NoAuditDataInPayload(t *testing.T) {
t.Parallel()

l, err := NewEntryFilter("foo == bar")
l, err := NewEntryFilter("operation == foo")
require.NoError(t, err)

a, err := NewEvent(RequestType)
Expand Down
9 changes: 7 additions & 2 deletions builtin/audit/file/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func TestBackend_configureFilterNode(t *testing.T) {
expectedErrorMsg string
}{
"happy": {
filter: "foo == bar",
filter: "operation == update",
},
"empty": {
filter: "",
Expand All @@ -266,6 +266,11 @@ func TestBackend_configureFilterNode(t *testing.T) {
wantErr: true,
expectedErrorMsg: "file.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter",
},
"unsupported-field": {
filter: "foo == bar",
wantErr: true,
expectedErrorMsg: "filter references an unsupported field: foo == bar",
},
}
for name, tc := range tests {
name := name
Expand Down Expand Up @@ -477,7 +482,7 @@ func TestBackend_configureFilterFormatterSink(t *testing.T) {
formatConfig, err := audit.NewFormatterConfig()
require.NoError(t, err)

err = b.configureFilterNode("foo == bar")
err = b.configureFilterNode("path == bar")
require.NoError(t, err)

err = b.configureFormatterNode(formatConfig)
Expand Down
9 changes: 7 additions & 2 deletions builtin/audit/socket/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestBackend_configureFilterNode(t *testing.T) {
expectedErrorMsg string
}{
"happy": {
filter: "foo == bar",
filter: "mount_point == \"/auth/token\"",
},
"empty": {
filter: "",
Expand All @@ -142,6 +142,11 @@ func TestBackend_configureFilterNode(t *testing.T) {
wantErr: true,
expectedErrorMsg: "socket.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter",
},
"unsupported-field": {
filter: "foo == bar",
wantErr: true,
expectedErrorMsg: "filter references an unsupported field: foo == bar",
},
}
for name, tc := range tests {
name := name
Expand Down Expand Up @@ -309,7 +314,7 @@ func TestBackend_configureFilterFormatterSink(t *testing.T) {
formatConfig, err := audit.NewFormatterConfig()
require.NoError(t, err)

err = b.configureFilterNode("foo == bar")
err = b.configureFilterNode("mount_type == kv")
require.NoError(t, err)

err = b.configureFormatterNode(formatConfig)
Expand Down
9 changes: 7 additions & 2 deletions builtin/audit/syslog/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestBackend_configureFilterNode(t *testing.T) {
expectedErrorMsg string
}{
"happy": {
filter: "foo == bar",
filter: "namespace == bar",
},
"empty": {
filter: "",
Expand All @@ -142,6 +142,11 @@ func TestBackend_configureFilterNode(t *testing.T) {
wantErr: true,
expectedErrorMsg: "syslog.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter",
},
"unsupported-field": {
filter: "foo == bar",
wantErr: true,
expectedErrorMsg: "filter references an unsupported field: foo == bar",
},
}
for name, tc := range tests {
name := name
Expand Down Expand Up @@ -291,7 +296,7 @@ func TestBackend_configureFilterFormatterSink(t *testing.T) {
formatConfig, err := audit.NewFormatterConfig()
require.NoError(t, err)

err = b.configureFilterNode("foo == bar")
err = b.configureFilterNode("mount_type == kv")
require.NoError(t, err)

err = b.configureFormatterNode(formatConfig)
Expand Down
9 changes: 4 additions & 5 deletions vault/audit_broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ import (
"testing"
"time"

"github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/helper/namespace"

"github.com/hashicorp/eventlogger"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/builtin/audit/syslog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
Expand Down Expand Up @@ -66,7 +65,7 @@ func TestAuditBroker_Register_SuccessThresholdSinks(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, a)

filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "foo == bar"})
filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "operation == create"})
noFilterBackend := testAuditBackend(t, "b2-no-filter", map[string]string{})

// Should be set to 0 for required sinks (and not found, as we've never registered before).
Expand Down Expand Up @@ -108,7 +107,7 @@ func TestAuditBroker_Deregister_SuccessThresholdSinks(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, a)

filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "foo == bar"})
filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "operation == create"})
noFilterBackend := testAuditBackend(t, "b2-no-filter", map[string]string{})

err = a.Register("b1-filter", filterBackend, false)
Expand Down
64 changes: 55 additions & 9 deletions vault/external_tests/audit/audit_filtering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ import (
func TestAuditFilteringOnDifferentFields(t *testing.T) {
t.Parallel()
cluster := minimal.NewTestSoloCluster(t, nil)
client, err := cluster.Cores[0].Client.Clone()
require.NoError(t, err)
client.SetToken(cluster.RootToken)
client := cluster.Cores[0].Client

// Create audit devices.
tempDir := t.TempDir()
Expand Down Expand Up @@ -115,9 +113,7 @@ func TestAuditFilteringOnDifferentFields(t *testing.T) {
func TestAuditFilteringMultipleDevices(t *testing.T) {
t.Parallel()
cluster := minimal.NewTestSoloCluster(t, nil)
client, err := cluster.Cores[0].Client.Clone()
require.NoError(t, err)
client.SetToken(cluster.RootToken)
client := cluster.Cores[0].Client

// Create audit devices.
tempDir := t.TempDir()
Expand Down Expand Up @@ -212,9 +208,7 @@ func TestAuditFilteringMultipleDevices(t *testing.T) {
func TestAuditFilteringFallbackDevice(t *testing.T) {
t.Parallel()
cluster := minimal.NewTestSoloCluster(t, nil)
client, err := cluster.Cores[0].Client.Clone()
require.NoError(t, err)
client.SetToken(cluster.RootToken)
client := cluster.Cores[0].Client

tempDir := t.TempDir()
fallbackLogFile, err := os.CreateTemp(tempDir, "")
Expand Down Expand Up @@ -289,6 +283,58 @@ func TestAuditFilteringFallbackDevice(t *testing.T) {
require.Equal(t, 5, numberOfEntries)
}

// TestAuditFilteringFilterForUnsupportedField validates that the audit device
// 'filter' option fails when the filter expression selector references an
// unsupported field and that the error prevents an audit device from created.
func TestAuditFilteringFilterForUnsupportedField(t *testing.T) {
t.Parallel()
cluster := minimal.NewTestSoloCluster(t, nil)
client := cluster.Cores[0].Client

tempDir := t.TempDir()
filteredLogFile, err := os.CreateTemp(tempDir, "")
filteredDevicePath := "filtered"
filteredDeviceData := map[string]any{
"type": "file",
"description": "",
"local": false,
"options": map[string]any{
"file_path": filteredLogFile.Name(),
"filter": "auth == foo", // 'auth' is not one of the fields we allow filtering on
},
}
_, err = client.Logical().Write("sys/audit/"+filteredDevicePath, filteredDeviceData)
require.Error(t, err)
require.ErrorContains(t, err, "audit.NewEntryFilter: filter references an unsupported field: auth == foo")

// Ensure the device has not been created.
devices, err := client.Sys().ListAudit()
require.NoError(t, err)
_, ok := devices[filteredDevicePath]
require.False(t, ok)

// Now we do the same test but with the 'skip_test' option set to true.
filteredDeviceDataSkipTest := map[string]any{
"type": "file",
"description": "",
"local": false,
"options": map[string]any{
"file_path": filteredLogFile.Name(),
"filter": "auth == foo", // 'auth' is not one of the fields we allow filtering on
"skip_test": true,
},
}
_, err = client.Logical().Write("sys/audit/"+filteredDevicePath, filteredDeviceDataSkipTest)
require.Error(t, err)
require.ErrorContains(t, err, "audit.NewEntryFilter: filter references an unsupported field: auth == foo")

// Ensure the device has not been created.
devices, err = client.Sys().ListAudit()
require.NoError(t, err)
_, ok = devices[filteredDevicePath]
require.False(t, ok)
}

// getFileSize returns the size of the given file in bytes.
func getFileSize(t *testing.T, filePath string) int64 {
t.Helper()
Expand Down

0 comments on commit a1295a5

Please sign in to comment.