diff --git a/ds/reactive/set.go b/ds/reactive/set.go index fd769a3fa..5de713fb9 100644 --- a/ds/reactive/set.go +++ b/ds/reactive/set.go @@ -29,6 +29,10 @@ type ReadableSet[ElementType comparable] interface { // OnUpdate registers the given callback that is triggered when the value changes. OnUpdate(callback func(appliedMutations ds.SetMutations[ElementType]), triggerWithInitialZeroValue ...bool) (unsubscribe func()) + // SubtractReactive returns a new set that will automatically be updated to always hold all elements of the current + // set minus the elements of the other sets. + SubtractReactive(others ...ReadableSet[ElementType]) Set[ElementType] + // ReadableSet imports the read methods of the Set interface. ds.ReadableSet[ElementType] } diff --git a/ds/reactive/set_impl.go b/ds/reactive/set_impl.go index fc5668b1a..6b5af885d 100644 --- a/ds/reactive/set_impl.go +++ b/ds/reactive/set_impl.go @@ -192,6 +192,57 @@ func (r *readableSet[ElementType]) OnUpdate(callback func(appliedMutations ds.Se } } +// SubtractReactive returns a new set that will automatically be updated to always hold all elements of the current set +// minus the elements of the other sets. +func (r *readableSet[ElementType]) SubtractReactive(others ...ReadableSet[ElementType]) Set[ElementType] { + elementCounters := shrinkingmap.New[ElementType, int]() + countMutations := func(targetSet, opposingSet ds.Set[ElementType], diff int, threshold int) func(addedElement ElementType) { + return func(element ElementType) { + if elementCounters.Compute(element, func(currentValue int, _ bool) int { + return currentValue + diff + }) == threshold && !opposingSet.Delete(element) { + targetSet.Add(element) + } + } + } + + addMutations := func(mutations ds.SetMutations[ElementType]) func(ElementType) { + return countMutations(mutations.AddedElements(), mutations.DeletedElements(), +1, 1) + } + + deleteMutations := func(mutations ds.SetMutations[ElementType]) func(addedElement ElementType) { + return countMutations(mutations.DeletedElements(), mutations.AddedElements(), -1, 0) + } + + s := NewSet[ElementType]() + r.OnUpdate(func(appliedMutations ds.SetMutations[ElementType]) { + s.Compute(func(elements ds.ReadableSet[ElementType]) ds.SetMutations[ElementType] { + mutations := ds.NewSetMutations[ElementType]() + + appliedMutations.AddedElements().Range(addMutations(mutations)) + appliedMutations.DeletedElements().Range(deleteMutations(mutations)) + + return mutations + }) + + }) + + for _, other := range others { + other.OnUpdate(func(appliedMutations ds.SetMutations[ElementType]) { + s.Compute(func(elements ds.ReadableSet[ElementType]) ds.SetMutations[ElementType] { + mutations := ds.NewSetMutations[ElementType]() + + appliedMutations.AddedElements().Range(deleteMutations(mutations)) + appliedMutations.DeletedElements().Range(addMutations(mutations)) + + return mutations + }) + }) + } + + return s +} + // endregion /////////////////////////////////////////////////////////////////////////////////////////////////////////// // region derivedSet /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/ds/reactive/set_test.go b/ds/reactive/set_test.go index 573cbde76..752fb0e7f 100644 --- a/ds/reactive/set_test.go +++ b/ds/reactive/set_test.go @@ -33,3 +33,45 @@ func TestSet(t *testing.T) { require.True(t, inheritedSet1.Has(7)) require.True(t, inheritedSet1.Has(9)) } + +func TestSubtract(t *testing.T) { + sourceSet := NewSet[int]() + sourceSet.Add(3) + + removedSet := NewSet[int]() + removedSet.Add(5) + + subtraction := sourceSet.SubtractReactive(removedSet) + require.True(t, subtraction.Has(3)) + require.Equal(t, 1, subtraction.Size()) + + sourceSet.Add(4) + require.True(t, subtraction.Has(3)) + require.True(t, subtraction.Has(4)) + require.Equal(t, 2, subtraction.Size()) + + removedSet.Add(4) + require.True(t, subtraction.Has(3)) + require.False(t, subtraction.Has(4)) + require.Equal(t, 1, subtraction.Size()) + + removedSet.Add(3) + require.False(t, subtraction.Has(3)) + require.False(t, subtraction.Has(4)) + require.Equal(t, 0, subtraction.Size()) + + sourceSet.Add(5) + require.False(t, subtraction.Has(3)) + require.False(t, subtraction.Has(4)) + require.False(t, subtraction.Has(3)) + require.False(t, subtraction.Has(5)) + require.Equal(t, 0, subtraction.Size()) + + sourceSet.Add(6) + require.False(t, subtraction.Has(3)) + require.False(t, subtraction.Has(4)) + require.False(t, subtraction.Has(3)) + require.False(t, subtraction.Has(5)) + require.True(t, subtraction.Has(6)) + require.Equal(t, 1, subtraction.Size()) +}