From b736b85a33045941747937051f738210df2759fe Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Tue, 6 Feb 2024 13:18:20 -0800 Subject: [PATCH] events: Add filter watch function (#25147) This allows other parts of the code to be able to watch for changes to the event filter map, i.e., so that they can know when the global or local filters are changed. This is going to be used in an upcoming enterprise PR for event replication to performance secondaries. --- vault/eventbus/bus.go | 101 ++++++++-- vault/eventbus/bus_test.go | 324 ++++++++++++++++++++++++++++++++ vault/eventbus/filter.go | 337 +++++++++++++++++++++++++++++----- vault/eventbus/filter_test.go | 122 +++++++++++- 4 files changed, 827 insertions(+), 57 deletions(-) diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index fd5102ad8f6d..b45839f1ac0b 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -164,7 +164,7 @@ func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.Even return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data) } -func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) { +func NewEventBus(localClusterID string, logger hclog.Logger) (*EventBus, error) { broker, err := eventlogger.NewBroker() if err != nil { return nil, err @@ -180,7 +180,7 @@ func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) { logger = hclog.Default().Named("events") } - sourceUrl, err := url.Parse("vault://" + localNodeID) + sourceUrl, err := url.Parse("vault://" + localClusterID) if err != nil { return nil, err } @@ -198,7 +198,7 @@ func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) { formatterNodeID: formatterNodeID, timeout: defaultTimeout, cloudEventsFormatterFilter: cloudEventsFormatterFilter, - filters: NewFilters(localNodeID), + filters: NewFilters(localClusterID), }, nil } @@ -211,6 +211,13 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pat // SubscribeMultipleNamespaces subscribes to events in the given namespace matching the event type // pattern and after applying the optional go-bexpr filter. func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespacePathPatterns []string, pattern string, bexprFilter string) (<-chan *eventlogger.Event, context.CancelFunc, error) { + return bus.subscribeInternal(ctx, namespacePathPatterns, pattern, bexprFilter, nil) +} + +// subscribeInternal creates the pipeline and connects it to the event bus to receive events. +// if the cluster is specified, then the namespacePathPatterns, pattern, and bexprFilter are ignored, and instead this +// subscription will be tied to the given cluster's filter. +func (bus *EventBus) subscribeInternal(ctx context.Context, namespacePathPatterns []string, pattern string, bexprFilter string, cluster *string) (<-chan *eventlogger.Event, context.CancelFunc, error) { // subscriptions are still stored even if the bus has not been started pipelineID, err := uuid.GenerateUUID() if err != nil { @@ -227,9 +234,15 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP return nil, nil, err } - filterNode, err := newFilterNode(namespacePathPatterns, pattern, bexprFilter) - if err != nil { - return nil, nil, err + var filterNode *eventlogger.Filter + if cluster != nil { + filterNode, err = newClusterFilterNode(bus.filters, clusterID(*cluster)) + } else { + filterNode, err = newFilterNode(namespacePathPatterns, pattern, bexprFilter) + if err != nil { + return nil, nil, err + } + bus.filters.addPattern(bus.filters.self, namespacePathPatterns, pattern) } err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode) if err != nil { @@ -242,11 +255,10 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP } ctx, cancel := context.WithCancel(ctx) - - bus.filters.addPattern(bus.filters.self, namespacePathPatterns, pattern) - asyncNode := newAsyncNode(ctx, bus.logger, bus.broker, func() { - bus.filters.removePattern(bus.filters.self, namespacePathPatterns, pattern) + if cluster == nil { + bus.filters.removePattern(bus.filters.self, namespacePathPatterns, pattern) + } }) err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) if err != nil { @@ -281,6 +293,73 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) { bus.timeout = timeout } +// GlobalMatch returns true if the given namespace and event type match the current global filter. +func (bus *EventBus) GlobalMatch(ns *namespace.Namespace, eventType logical.EventType) bool { + return bus.filters.globalMatch(ns, eventType) +} + +// ApplyClusterFilterChanges applies the given filter changes to the cluster's filters. +func (bus *EventBus) ApplyClusterFilterChanges(c string, changes []FilterChange) { + bus.filters.applyChanges(clusterID(c), changes) +} + +// ApplyGlobalFilterChanges applies the given filter changes to the global filters. +func (bus *EventBus) ApplyGlobalFilterChanges(changes []FilterChange) { + bus.filters.applyChanges(globalCluster, changes) +} + +// ClearGlobalFilter removes all entries from the current global filter. +func (bus *EventBus) ClearGlobalFilter() { + bus.filters.clearGlobalPatterns() +} + +// ClearClusterFilter removes all entries from the given cluster's filter. +func (bus *EventBus) ClearClusterFilter(id string) { + bus.filters.clearClusterPatterns(clusterID(id)) +} + +// NotifyOnGlobalFilterChanges returns a channel that receives changes to the global filter. +func (bus *EventBus) NotifyOnGlobalFilterChanges(ctx context.Context) (<-chan []FilterChange, context.CancelFunc, error) { + return bus.filters.watch(ctx, globalCluster) +} + +// NotifyOnLocalFilterChanges returns a channel that receives changes to the filter for the current cluster. +func (bus *EventBus) NotifyOnLocalFilterChanges(ctx context.Context) (<-chan []FilterChange, context.CancelFunc, error) { + return bus.NotifyOnClusterFilterChanges(ctx, string(bus.filters.self)) +} + +// NotifyOnClusterFilterChanges returns a channel that receives changes to the filter for the given cluster. +func (bus *EventBus) NotifyOnClusterFilterChanges(ctx context.Context, cluster string) (<-chan []FilterChange, context.CancelFunc, error) { + return bus.filters.watch(ctx, clusterID(cluster)) +} + +// NewGlobalSubscription creates a new subscription to all events that match the global filter. +func (bus *EventBus) NewGlobalSubscription(ctx context.Context) (<-chan *eventlogger.Event, context.CancelFunc, error) { + g := globalCluster + return bus.subscribeInternal(ctx, nil, "", "", &g) +} + +// NewClusterSubscription creates a new subscription to all events that match the given cluster's filter. +func (bus *EventBus) NewClusterSubscription(ctx context.Context, cluster string) (<-chan *eventlogger.Event, context.CancelFunc, error) { + return bus.subscribeInternal(ctx, nil, "", "", &cluster) +} + +// creates a new filter node that is tied to the filter for a given cluster +func newClusterFilterNode(filters *Filters, c clusterID) (*eventlogger.Filter, error) { + return &eventlogger.Filter{ + Predicate: func(e *eventlogger.Event) (bool, error) { + eventRecv := e.Payload.(*logical.EventReceived) + eventNs := strings.Trim(eventRecv.Namespace, "/") + if filters.clusterMatch(c, &namespace.Namespace{ + Path: eventNs, + }, logical.EventType(eventRecv.EventType)) { + return true, nil + } + return false, nil + }, + }, nil +} + func newFilterNode(namespacePatterns []string, pattern string, bexprFilter string) (*eventlogger.Filter, error) { var evaluator *bexpr.Evaluator if bexprFilter != "" { @@ -308,7 +387,7 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin } } - // NodeFilter for correct event type, including wildcards. + // ClusterFilter for correct event type, including wildcards. if !glob.Glob(pattern, eventRecv.EventType) { return false, nil } diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index 0255f7fe8678..cb4c12e950fc 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -703,3 +703,327 @@ func TestPipelineCleanedUp(t *testing.T) { t.Fatal() } } + +// TestSubscribeGlobal tests that the global filter subscription mechanism works. +func TestSubscribeGlobal(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + bus.filters.addGlobalPattern([]string{""}, "abc*") + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + ch, cancel2, err := bus.NewGlobalSubscription(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case recv := <-ch: + // ok + event := recv.Payload.(*logical.EventReceived) + assert.Equal(t, "abcd", event.EventType) + case <-time.After(1 * time.Second): + t.Fatal("Timed out waiting for event") + } +} + +// TestSubscribeGlobal_WithApply tests that the global filter subscription mechanism works when using ApplyGlobalFilterChanges. +func TestSubscribeGlobal_WithApply(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + assert.False(t, bus.GlobalMatch(namespace.RootNamespace, "abcd")) + bus.ApplyGlobalFilterChanges([]FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + assert.True(t, bus.GlobalMatch(namespace.RootNamespace, "abcd")) + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + ch, cancel2, err := bus.NewGlobalSubscription(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case recv := <-ch: + // ok + event := recv.Payload.(*logical.EventReceived) + assert.Equal(t, "abcd", event.EventType) + case <-time.After(1 * time.Second): + t.Fatal("Timed out waiting for event") + } +} + +// TestSubscribeCluster tests that the cluster filter subscription mechanism works. +func TestSubscribeCluster(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + bus.filters.addPattern("somecluster", []string{""}, "abc*") + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + ch, cancel2, err := bus.NewClusterSubscription(ctx, "somecluster") + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case recv := <-ch: + // ok + event := recv.Payload.(*logical.EventReceived) + assert.Equal(t, "abcd", event.EventType) + case <-time.After(1 * time.Second): + t.Fatal("Timed out waiting for event") + } +} + +// TestSubscribeCluster_WithApply tests that the cluster filter subscription mechanism works when using ApplyClusterFilterChanges. +func TestSubscribeCluster_WithApply(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + bus.ApplyClusterFilterChanges("somecluster", []FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + ch, cancel2, err := bus.NewClusterSubscription(ctx, "somecluster") + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case recv := <-ch: + // ok + event := recv.Payload.(*logical.EventReceived) + assert.Equal(t, "abcd", event.EventType) + case <-time.After(1 * time.Second): + t.Fatal("Timed out waiting for event") + } +} + +// TestClearGlobalFilter tests that clearing the global filter means no messages get through. +func TestClearGlobalFilter(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + bus.ApplyGlobalFilterChanges([]FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + bus.ClearGlobalFilter() + ch, cancel2, err := bus.NewGlobalSubscription(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case <-ch: + t.Fatal("We should not have gotten an event") + case <-time.After(100 * time.Millisecond): + // ok + } +} + +// TestClearClusterFilter tests that clearing a cluster filter means no messages get through. +func TestClearClusterFilter(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + bus.ApplyClusterFilterChanges("somecluster", []FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + bus.ClearClusterFilter("somecluster") + ch, cancel2, err := bus.NewClusterSubscription(ctx, "somecluster") + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + ev, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev) + if err != nil { + t.Fatal(err) + } + + select { + case <-ch: + t.Fatal("We should not have gotten an event") + case <-time.After(100 * time.Millisecond): + // ok + } +} + +// TestNotifyOnGlobalFilterChanges tests that notifications on global filter changes are sent. +func TestNotifyOnGlobalFilterChanges(t *testing.T) { + bus, err := NewEventBus("", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + + ch, cancel2, err := bus.NotifyOnGlobalFilterChanges(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + bus.ApplyGlobalFilterChanges([]FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + + select { + case changes := <-ch: + if len(changes) == 2 { + assert.Equal(t, []FilterChange{{Operation: FilterChangeClear}, {Operation: FilterChangeAdd, NamespacePatterns: []string{""}, EventTypePattern: "abc*"}}, changes) + } else { + // could be split into two updates + assert.Len(t, changes, 1) + assert.Equal(t, []FilterChange{{Operation: FilterChangeClear}}, changes) + changes := <-ch + assert.Len(t, changes, 1) + assert.Equal(t, []FilterChange{{Operation: FilterChangeAdd, NamespacePatterns: []string{""}, EventTypePattern: "abc*"}}, changes) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("We expected to get a global filter notification") + } +} + +// TestNotifyOnLocalFilterChanges tests that notifications on local cluster filter changes are sent. +func TestNotifyOnLocalFilterChanges(t *testing.T) { + bus, err := NewEventBus("somecluster", nil) + if err != nil { + t.Fatal(err) + } + + bus.Start() + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + + ch, cancel2, err := bus.NotifyOnLocalFilterChanges(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cancel2) + bus.ApplyClusterFilterChanges("somecluster", []FilterChange{ + { + Operation: FilterChangeAdd, + NamespacePatterns: []string{""}, + EventTypePattern: "abc*", + }, + }) + + select { + case changes := <-ch: + if len(changes) == 2 { + assert.Equal(t, []FilterChange{{Operation: FilterChangeClear}, {Operation: FilterChangeAdd, NamespacePatterns: []string{""}, EventTypePattern: "abc*"}}, changes) + } else { + // could be split into two updates + assert.Len(t, changes, 1) + assert.Equal(t, []FilterChange{{Operation: FilterChangeClear}}, changes) + changes := <-ch + assert.Len(t, changes, 1) + assert.Equal(t, []FilterChange{{Operation: FilterChangeAdd, NamespacePatterns: []string{""}, EventTypePattern: "abc*"}}, changes) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("We expected to get a global filter notification") + } +} diff --git a/vault/eventbus/filter.go b/vault/eventbus/filter.go index ec91a7e0d16f..d4d496964f33 100644 --- a/vault/eventbus/filter.go +++ b/vault/eventbus/filter.go @@ -4,43 +4,62 @@ package eventbus import ( + "context" + "fmt" "slices" "sort" + "strings" "sync" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" "github.com/ryanuber/go-glob" + "k8s.io/apimachinery/pkg/util/sets" ) -// Filters keeps track of all the event patterns that each node is interested in. +const globalCluster = "" + +// Filters keeps track of all the event patterns that each cluster is interested in. type Filters struct { lock sync.RWMutex - self nodeID - filters map[nodeID]*NodeFilter + self clusterID + filters map[clusterID]*ClusterFilter + + // notifyChanges is used to notify about changes to filters. The condition variables are tied to single lock above. + notifyChanges map[clusterID]*sync.Cond } -// nodeID is used to syntactically indicate that the string is a node's name identifier. -type nodeID string +// clusterID is used to syntactically indicate that the string is a cluster's identifier. +type clusterID string // pattern is used to represent one or more combinations of patterns type pattern struct { + namespacePatterns string // space-separated (spaces are not allowed in namespaces, and slices are not comparable) eventTypePattern string - namespacePatterns []string } -// NodeFilter keeps track of all patterns that a particular node is interested in. -type NodeFilter struct { - patterns []pattern +func (p pattern) String() string { + return fmt.Sprintf("{ns=%s,ev=%s}", p.namespacePatterns, p.eventTypePattern) +} + +func (p pattern) isEmpty() bool { + return p.namespacePatterns == "" && p.eventTypePattern == "" +} + +// ClusterFilter keeps track of all patterns that a particular cluster is interested in. +type ClusterFilter struct { + patterns sets.Set[pattern] } -func (nf *NodeFilter) match(ns *namespace.Namespace, eventType logical.EventType) bool { +// match checks if the given ns and eventType matches any pattern in the cluster's filter. +// Must be called while holding a (read) lock for the filter. +func (nf *ClusterFilter) match(ns *namespace.Namespace, eventType logical.EventType) bool { if nf == nil { return false } - for _, p := range nf.patterns { + for p := range nf.patterns { if glob.Glob(p.eventTypePattern, string(eventType)) { - for _, nsp := range p.namespacePatterns { + for _, nsp := range strings.Split(p.namespacePatterns, " ") { if glob.Glob(nsp, ns.Path) { return true } @@ -50,71 +69,301 @@ func (nf *NodeFilter) match(ns *namespace.Namespace, eventType logical.EventType return false } -// NewFilters creates an empty set of filters to keep track of each node's pattern interests. +// NewFilters creates an empty set of filters to keep track of each cluster's pattern interests. func NewFilters(self string) *Filters { - return &Filters{ - self: nodeID(self), - filters: map[nodeID]*NodeFilter{}, + f := &Filters{ + self: clusterID(self), + filters: map[clusterID]*ClusterFilter{}, + notifyChanges: map[clusterID]*sync.Cond{}, + } + f.notifyChanges[clusterID(self)] = sync.NewCond(&f.lock) + f.notifyChanges[globalCluster] = sync.NewCond(&f.lock) + return f +} + +func (f *Filters) String() string { + x := "Filters {\n" + for k, v := range f.filters { + x += fmt.Sprintf(" %s: {%s}\n", k, v) + } + return x +} + +func (nf *ClusterFilter) String() string { + var x []string + l := nf.patterns.UnsortedList() + for _, v := range l { + x = append(x, v.String()) + } + return strings.Join(x, ",") +} + +func (f *Filters) addGlobalPattern(namespacePatterns []string, eventTypePattern string) { + f.addPattern(globalCluster, namespacePatterns, eventTypePattern) +} + +func (f *Filters) removeGlobalPattern(namespacePatterns []string, eventTypePattern string) { + f.removePattern(globalCluster, namespacePatterns, eventTypePattern) +} + +func (f *Filters) clearGlobalPatterns() { + defer f.notify(globalCluster) + f.lock.Lock() + defer f.lock.Unlock() + delete(f.filters, globalCluster) +} + +func (f *Filters) getOrCreateNotify(c clusterID) *sync.Cond { + // fast check when we don't need to create the Cond + f.lock.RLock() + n, ok := f.notifyChanges[c] + f.lock.RUnlock() + if ok { + return n + } + f.lock.Lock() + defer f.lock.Unlock() + // check again to avoid race condition + n, ok = f.notifyChanges[c] + if ok { + return n + } + n = sync.NewCond(&f.lock) + f.notifyChanges[c] = n + return n +} + +func (f *Filters) notify(c clusterID) { + f.lock.RLock() + defer f.lock.RUnlock() + if notifier, ok := f.notifyChanges[c]; ok { + notifier.Broadcast() + } +} + +func (f *Filters) clearClusterPatterns(c clusterID) { + defer f.notify(c) + f.lock.Lock() + defer f.lock.Unlock() + delete(f.filters, c) +} + +// copyPatternWithLock gets a copy of a cluster's filters +func (f *Filters) copyPatternWithLock(c clusterID) *ClusterFilter { + filters := &ClusterFilter{} + if got, ok := f.filters[c]; ok { + filters.patterns = got.patterns.Clone() + } else { + filters.patterns = sets.New[pattern]() + } + return filters +} + +// applyChanges applies the changes in the given list, atomically. +func (f *Filters) applyChanges(c clusterID, changes []FilterChange) { + defer f.notify(c) + f.lock.Lock() + defer f.lock.Unlock() + var newPatterns sets.Set[pattern] + if existing, ok := f.filters[c]; ok { + newPatterns = existing.patterns + } else { + newPatterns = sets.New[pattern]() + } + for _, change := range changes { + applyChange(newPatterns, &change) + } + f.filters[c] = &ClusterFilter{patterns: newPatterns} +} + +// applyChange applies a single filter change to the given set. +func applyChange(s sets.Set[pattern], change *FilterChange) { + switch change.Operation { + case FilterChangeAdd: + nsPatterns := slices.Clone(change.NamespacePatterns) + sort.Strings(nsPatterns) + p := pattern{eventTypePattern: change.EventTypePattern, namespacePatterns: cleanJoinNamespaces(nsPatterns)} + s.Insert(p) + case FilterChangeRemove: + nsPatterns := slices.Clone(change.NamespacePatterns) + sort.Strings(nsPatterns) + check := pattern{eventTypePattern: change.EventTypePattern, namespacePatterns: cleanJoinNamespaces(nsPatterns)} + s.Delete(check) + case FilterChangeClear: + s.Clear() + } +} + +func cleanJoinNamespaces(nsPatterns []string) string { + trimmed := make([]string, len(nsPatterns)) + for i := 0; i < len(nsPatterns); i++ { + trimmed[i] = strings.TrimSpace(nsPatterns[i]) } + // sort and uniq + trimmed = sets.NewString(trimmed...).List() + return strings.Join(trimmed, " ") } // addPattern adds a pattern to a node's list. -func (f *Filters) addPattern(node nodeID, namespacePatterns []string, eventTypePattern string) { +func (f *Filters) addPattern(c clusterID, namespacePatterns []string, eventTypePattern string) { + defer f.notify(c) f.lock.Lock() defer f.lock.Unlock() - if _, ok := f.filters[node]; !ok { - f.filters[node] = &NodeFilter{} + if _, ok := f.filters[c]; !ok { + f.filters[c] = &ClusterFilter{ + patterns: sets.New[pattern](), + } } nsPatterns := slices.Clone(namespacePatterns) sort.Strings(nsPatterns) - f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns}) + p := pattern{eventTypePattern: eventTypePattern, namespacePatterns: cleanJoinNamespaces(namespacePatterns)} + f.filters[c].patterns.Insert(p) } -func (f *Filters) addNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { - f.addPattern(node, []string{ns.Path}, eventTypePattern) -} - -// removePattern removes a pattern from a node's list. -func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTypePattern string) { +// removePattern removes a pattern from a cluster's list. +func (f *Filters) removePattern(c clusterID, namespacePatterns []string, eventTypePattern string) { + defer f.notify(c) nsPatterns := slices.Clone(namespacePatterns) sort.Strings(nsPatterns) - check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns} + check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: cleanJoinNamespaces(nsPatterns)} f.lock.Lock() defer f.lock.Unlock() - filters, ok := f.filters[node] + filters, ok := f.filters[c] if !ok { return } - filters.patterns = slices.DeleteFunc(filters.patterns, func(m pattern) bool { - return m.eventTypePattern == check.eventTypePattern && - slices.Equal(m.namespacePatterns, check.namespacePatterns) - }) -} - -func (f *Filters) removeNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { - f.removePattern(node, []string{ns.Path}, eventTypePattern) + filters.patterns.Delete(check) } -// anyMatch returns true if any node's pattern list matches the arguments. +// anyMatch returns true if any cluster's pattern list matches the arguments. func (f *Filters) anyMatch(ns *namespace.Namespace, eventType logical.EventType) bool { f.lock.RLock() defer f.lock.RUnlock() - for _, nf := range f.filters { - if nf.match(ns, eventType) { + for _, cf := range f.filters { + if cf.match(ns, eventType) { return true } } return false } -// nodeMatch returns true if the given node's pattern list matches the arguments. -func (f *Filters) nodeMatch(node nodeID, ns *namespace.Namespace, eventType logical.EventType) bool { +// globalMatch returns true if the global cluster's pattern list matches the arguments. +func (f *Filters) globalMatch(ns *namespace.Namespace, eventType logical.EventType) bool { + return f.clusterMatch(globalCluster, ns, eventType) +} + +// clusterMatch returns true if the given cluster's pattern list matches the arguments. +func (f *Filters) clusterMatch(c clusterID, ns *namespace.Namespace, eventType logical.EventType) bool { f.lock.RLock() defer f.lock.RUnlock() - return f.filters[node].match(ns, eventType) + return f.filters[c].match(ns, eventType) } -// localMatch returns true if the local node's pattern list matches the arguments. +// localMatch returns true if the local cluster's pattern list matches the arguments. func (f *Filters) localMatch(ns *namespace.Namespace, eventType logical.EventType) bool { - return f.nodeMatch(f.self, ns, eventType) + return f.clusterMatch(f.self, ns, eventType) +} + +// watch creates a notification channel that receives changes for the given cluster. +func (f *Filters) watch(ctx context.Context, cluster clusterID) (<-chan []FilterChange, context.CancelFunc, error) { + notify := f.getOrCreateNotify(cluster) + ctx, cancelFunc := context.WithCancel(ctx) + doneCh := ctx.Done() + ch := make(chan []FilterChange) + + // ensure that the sleeping goroutine wakes up if the channel is closed + go func() { + select { + case <-doneCh: + notify.Broadcast() + } + }() + + // actual watcher goroutine that waits for notifications and calculates changes + go func() { + var current *ClusterFilter + for { + done := func() bool { + f.lock.Lock() + defer f.lock.Unlock() + next := f.copyPatternWithLock(cluster) + changes := calculateChanges(current, next) + current = next + // check if the context is finished before sending + select { + case <-doneCh: + close(ch) + return true + default: + go func() { + ch <- changes + }() + } + notify.Wait() + return false + }() + if done { + return + } + } + }() + return ch, cancelFunc, nil +} + +// FilterChange represents a change to a cluster's filters. +type FilterChange struct { + Operation int + NamespacePatterns []string + EventTypePattern string +} + +const ( + FilterChangeAdd = 0 + FilterChangeRemove = 1 + FilterChangeClear = 2 +) + +// calculateChanges calculates a set of changes necessary to transform from into to. +func calculateChanges(from *ClusterFilter, to *ClusterFilter) []FilterChange { + var changes []FilterChange + if to == nil { + changes = append(changes, FilterChange{ + Operation: FilterChangeClear, + }) + } else if from == nil { + changes = append(changes, FilterChange{ + Operation: FilterChangeClear, + }) + for pattern := range to.patterns { + if !pattern.isEmpty() { + changes = append(changes, FilterChange{ + Operation: FilterChangeAdd, + NamespacePatterns: strings.Split(pattern.namespacePatterns, " "), + EventTypePattern: pattern.eventTypePattern, + }) + } + } + } else { + additions := to.patterns.Difference(from.patterns) + subtractions := from.patterns.Difference(to.patterns) + for add := range additions { + if !add.isEmpty() { + changes = append(changes, FilterChange{ + Operation: FilterChangeAdd, + NamespacePatterns: strings.Split(add.namespacePatterns, " "), + EventTypePattern: add.eventTypePattern, + }) + } + } + for sub := range subtractions { + if !sub.isEmpty() { + changes = append(changes, FilterChange{ + Operation: FilterChangeRemove, + NamespacePatterns: strings.Split(sub.namespacePatterns, " "), + EventTypePattern: sub.eventTypePattern, + }) + } + } + } + return changes } diff --git a/vault/eventbus/filter_test.go b/vault/eventbus/filter_test.go index fe5044563fda..4fb661f55763 100644 --- a/vault/eventbus/filter_test.go +++ b/vault/eventbus/filter_test.go @@ -4,12 +4,27 @@ package eventbus import ( + "context" + "fmt" + "strings" + "sync" "testing" + "time" "github.com/hashicorp/vault/helper/namespace" "github.com/stretchr/testify/assert" ) +// TestCleanJoinNamespaces tests some cases in cleanJoinNamespaces. +func TestCleanJoinNamespaces(t *testing.T) { + assert.Equal(t, "", cleanJoinNamespaces([]string{""})) + assert.Equal(t, " abc", cleanJoinNamespaces([]string{"", "abc"})) + // just checking that inverting works as expected + assert.Equal(t, []string{"", "abc"}, strings.Split(" abc", " ")) + assert.Equal(t, "abc", cleanJoinNamespaces([]string{"abc"})) + assert.Equal(t, "abc def", cleanJoinNamespaces([]string{"def", "abc"})) +} + // TestFilters_AddRemoveMatchLocal checks that basic matching, adding, and removing of patterns all work. func TestFilters_AddRemoveMatchLocal(t *testing.T) { f := NewFilters("self") @@ -20,12 +35,115 @@ func TestFilters_AddRemoveMatchLocal(t *testing.T) { assert.False(t, f.localMatch(ns, "abc")) assert.False(t, f.anyMatch(ns, "abc")) - f.addNsPattern("self", ns, "abc") + f.addPattern("self", []string{ns.Path}, "abc") assert.True(t, f.localMatch(ns, "abc")) assert.False(t, f.localMatch(ns, "abcd")) assert.True(t, f.anyMatch(ns, "abc")) assert.False(t, f.anyMatch(ns, "abcd")) - f.removeNsPattern("self", ns, "abc") + f.removePattern("self", []string{ns.Path}, "abc") assert.False(t, f.localMatch(ns, "abc")) assert.False(t, f.anyMatch(ns, "abc")) } + +// TestFilters_Watch checks that adding a watch for a cluster will send a notification when the patterns are modified. +func TestFilters_Watch(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelFunc) + f := NewFilters("self") + f.addPattern("self", []string{"ns1"}, "e3") + ch, cancelFunc2, err := f.watch(ctx, "self") + assert.Nil(t, err) + t.Cleanup(cancelFunc2) + initial := <-ch // we always get one immediately for the current state + assert.Len(t, initial, 2) + assert.Equal(t, FilterChangeClear, initial[0].Operation) + assert.Equal(t, FilterChangeAdd, initial[1].Operation) + assert.Equal(t, []string{"ns1"}, initial[1].NamespacePatterns) + assert.Equal(t, "e3", initial[1].EventTypePattern) + + go func() { + f.addPattern("self", []string{"ns1"}, "e2") + }() + changes := waitForChanges(t, ch) + assert.Equal(t, []FilterChange{{ + Operation: FilterChangeAdd, + NamespacePatterns: []string{"ns1"}, + EventTypePattern: "e2", + }}, changes) + go func() { + f.removePattern("self", []string{"ns1"}, "e3") + }() + changes = waitForChanges(t, ch) + assert.Equal(t, []FilterChange{{ + Operation: FilterChangeRemove, + NamespacePatterns: []string{"ns1"}, + EventTypePattern: "e3", + }}, changes) +} + +func waitForChanges(t *testing.T, ch <-chan []FilterChange) []FilterChange { + t.Helper() + timeout := time.After(2000 * time.Millisecond) + var changes []FilterChange + select { + case changes = <-ch: + case <-timeout: + fmt.Println("Timeout waiting for changes") + } + return changes +} + +// TestFilters_WatchCancel checks that calling the cancel function will clean up the channel. +func TestFilters_WatchCancel(t *testing.T) { + f := NewFilters("self") + f.addPattern("self", []string{"ns1"}, "e3") + ch, cancelFunc, err := f.watch(context.Background(), "self") + assert.Nil(t, err) + t.Cleanup(cancelFunc) + initial := <-ch // we always get one immediately for the current state + assert.Len(t, initial, 2) + assert.Equal(t, FilterChangeClear, initial[0].Operation) + assert.Equal(t, FilterChangeAdd, initial[1].Operation) + assert.Equal(t, []string{"ns1"}, initial[1].NamespacePatterns) + assert.Equal(t, "e3", initial[1].EventTypePattern) + + var changes []FilterChange + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + changes = waitForChanges(t, ch) + wg.Done() + }() + + cancelFunc() + wg.Wait() + assert.Nil(t, changes) + select { + case _, ok := <-ch: + assert.False(t, ok) + default: + t.Fatal("Channel should be closed") + } +} + +// TestFilters_AddRemoveClear tests that add/remove/clear works as expected. +func TestFilters_AddRemoveClear(t *testing.T) { + f := NewFilters("self") + f.addPattern("somecluster", []string{"ns1"}, "abc") + f.removePattern("somecluster", []string{"ns1"}, "abcd") + assert.Equal(t, "{ns=ns1,ev=abc}", f.filters["somecluster"].String()) + f.removePattern("somecluster", []string{"ns1"}, "abc") + assert.Equal(t, "", f.filters["somecluster"].String()) + f.addPattern("somecluster", []string{"ns1"}, "abc") + f.clearClusterPatterns("somecluster") + assert.NotContains(t, f.filters, "somecluster") + + f.addGlobalPattern([]string{"ns1"}, "abc") + f.removeGlobalPattern([]string{"ns1"}, "abcd") + assert.Equal(t, "{ns=ns1,ev=abc}", f.filters[globalCluster].String()) + f.removeGlobalPattern([]string{"ns1"}, "abc") + assert.Equal(t, "", f.filters[globalCluster].String()) + f.addGlobalPattern([]string{"ns1"}, "abc") + f.clearGlobalPatterns() + assert.NotContains(t, f.filters, globalCluster) +}