Skip to content

Commit

Permalink
Merge pull request #536 from iotaledger/feat/set-subtractreactive
Browse files Browse the repository at this point in the history
Feat: Added SubtractReactive method to Set
  • Loading branch information
karimodm authored Jul 31, 2023
2 parents eec9937 + 53dd865 commit 84bc933
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 23 deletions.
4 changes: 4 additions & 0 deletions ds/reactive/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ type ReadableSet[ElementType comparable] interface {
// OnUpdate registers the given callback that is triggered when the value changes.
OnUpdate(callback func(appliedMutations ds.SetMutations[ElementType]), triggerWithInitialZeroValue ...bool) (unsubscribe func())

// SubtractReactive returns a new set that will automatically be updated to always hold all elements of the current
// set minus the elements of the other sets.
SubtractReactive(others ...ReadableSet[ElementType]) Set[ElementType]

// ReadableSet imports the read methods of the Set interface.
ds.ReadableSet[ElementType]
}
Expand Down
53 changes: 30 additions & 23 deletions ds/reactive/set_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,30 @@ func (r *readableSet[ElementType]) OnUpdate(callback func(appliedMutations ds.Se
}
}

// SubtractReactive returns a new set that will automatically be updated to always hold all elements of the current set
// minus the elements of the other sets.
func (r *readableSet[ElementType]) SubtractReactive(others ...ReadableSet[ElementType]) Set[ElementType] {
s := NewSet[ElementType]()

setArithmetic := ds.NewSetArithmetic[ElementType]()

r.OnUpdate(func(mutations ds.SetMutations[ElementType]) {
s.Compute(func(ds.ReadableSet[ElementType]) ds.SetMutations[ElementType] {
return setArithmetic.Add(mutations)
})
})

for _, other := range others {
other.OnUpdate(func(mutations ds.SetMutations[ElementType]) {
s.Compute(func(ds.ReadableSet[ElementType]) ds.SetMutations[ElementType] {
return setArithmetic.Subtract(mutations)
})
})
}

return s
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region derivedSet ///////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -201,16 +225,15 @@ type derivedSet[ElementType comparable] struct {
// set is the set that is derived from the source sets.
*set[ElementType]

// sourceCounters are the counters that keep track of the number of times an element is contained in the source
// sets (we only want to remove an element from the set if it is not contained in any of the source sets anymore).
sourceCounters *shrinkingmap.ShrinkingMap[ElementType, int]
// setArithmetic is used to track the amount of times an element has been added or removed by any of the sources.
setArithmetic ds.SetArithmetic[ElementType]
}

// newDerivedSet creates a new derivedSet with the given elements.
func newDerivedSet[ElementType comparable]() *derivedSet[ElementType] {
return &derivedSet[ElementType]{
set: newSet[ElementType](),
sourceCounters: shrinkingmap.New[ElementType, int](),
set: newSet[ElementType](),
setArithmetic: ds.NewSetArithmetic[ElementType](),
}
}

Expand Down Expand Up @@ -259,24 +282,8 @@ func (s *derivedSet[ElementType]) applyInheritedMutations(mutations ds.SetMutati
defer s.readableSet.mutex.Unlock()

inheritedMutations = ds.NewSetMutations[ElementType]()

elementsToAdd := inheritedMutations.AddedElements()
mutations.AddedElements().Range(func(element ElementType) {
if s.sourceCounters.Compute(element, func(currentValue int, _ bool) int {
return currentValue + 1
}) == 1 {
elementsToAdd.Add(element)
}
})

elementsToDelete := inheritedMutations.DeletedElements()
mutations.DeletedElements().Range(func(element ElementType) {
if s.sourceCounters.Compute(element, func(currentValue int, _ bool) int {
return currentValue - 1
}) == 0 && !elementsToAdd.Delete(element) {
elementsToDelete.Add(element)
}
})
mutations.AddedElements().Range(s.setArithmetic.AddedElementsCollector(inheritedMutations))
mutations.DeletedElements().Range(s.setArithmetic.SubtractedElementsCollector(inheritedMutations))

return s.value.Apply(inheritedMutations), s.uniqueUpdateID.Next(), s.updateCallbacks.Values()
}
Expand Down
42 changes: 42 additions & 0 deletions ds/reactive/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,45 @@ func TestSet(t *testing.T) {
require.True(t, inheritedSet1.Has(7))
require.True(t, inheritedSet1.Has(9))
}

func TestSubtract(t *testing.T) {
sourceSet := NewSet[int]()
sourceSet.Add(3)

removedSet := NewSet[int]()
removedSet.Add(5)

subtraction := sourceSet.SubtractReactive(removedSet)
require.True(t, subtraction.Has(3))
require.Equal(t, 1, subtraction.Size())

sourceSet.Add(4)
require.True(t, subtraction.Has(3))
require.True(t, subtraction.Has(4))
require.Equal(t, 2, subtraction.Size())

removedSet.Add(4)
require.True(t, subtraction.Has(3))
require.False(t, subtraction.Has(4))
require.Equal(t, 1, subtraction.Size())

removedSet.Add(3)
require.False(t, subtraction.Has(3))
require.False(t, subtraction.Has(4))
require.Equal(t, 0, subtraction.Size())

sourceSet.Add(5)
require.False(t, subtraction.Has(3))
require.False(t, subtraction.Has(4))
require.False(t, subtraction.Has(3))
require.False(t, subtraction.Has(5))
require.Equal(t, 0, subtraction.Size())

sourceSet.Add(6)
require.False(t, subtraction.Has(3))
require.False(t, subtraction.Has(4))
require.False(t, subtraction.Has(3))
require.False(t, subtraction.Has(5))
require.True(t, subtraction.Has(6))
require.Equal(t, 1, subtraction.Size())
}
29 changes: 29 additions & 0 deletions ds/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,32 @@ func NewSetMutations[ElementType comparable](elements ...ElementType) SetMutatio
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region SetArithmetic ////////////////////////////////////////////////////////////////////////////////////////////////

// SetArithmetic is an interface that allows to perform arithmetic operations on a set of elements to return the
// resulting mutations of the operation.
type SetArithmetic[ElementType comparable] interface {
// Add adds the given mutations to the elements and returns the resulting net mutations for the set that are formed
// by tracking the elements that rise above the given threshold (defaults to 1).
Add(mutations SetMutations[ElementType], threshold ...int) SetMutations[ElementType]

// AddedElementsCollector returns a function that adds an element to the given mutations if its occurrence count
// reaches the given threshold (defaults to 1) after the addition.
AddedElementsCollector(mutations SetMutations[ElementType], threshold ...int) func(addedElement ElementType)

// Subtract subtracts the given mutations from the elements and returns the resulting net mutations for the set that
// are formed by tracking the elements that fall below the given threshold (defaults to 1).
Subtract(mutations SetMutations[ElementType], threshold ...int) SetMutations[ElementType]

// SubtractedElementsCollector returns a function that deletes an element from the given mutations if its occurrence
// count falls below the given threshold (defaults to 1) after the subtraction.
SubtractedElementsCollector(mutations SetMutations[ElementType], threshold ...int) func(ElementType)
}

// NewSetArithmetic creates a new SetArithmetic instance.
func NewSetArithmetic[ElementType comparable]() SetArithmetic[ElementType] {
return newSetArithmetic[ElementType]()
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////
64 changes: 64 additions & 0 deletions ds/set_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sync"

"github.com/iotaledger/hive.go/ds/orderedmap"
"github.com/iotaledger/hive.go/ds/shrinkingmap"
"github.com/iotaledger/hive.go/ds/types"
"github.com/iotaledger/hive.go/ds/walker"
"github.com/iotaledger/hive.go/ierrors"
Expand Down Expand Up @@ -321,3 +322,66 @@ func (m *setMutations[ElementType]) IsEmpty() bool {
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

// region setArithmetic ////////////////////////////////////////////////////////////////////////////////////////////////

// setArithmetic is the default implementation of the SetArithmetic interface.
type setArithmetic[ElementType comparable] struct {
// ShrinkingMap is used to keep track of the number of times an element is present.
*shrinkingmap.ShrinkingMap[ElementType, int]
}

// newSetArithmetic creates a new setArithmetic instance.
func newSetArithmetic[ElementType comparable]() *setArithmetic[ElementType] {
return &setArithmetic[ElementType]{
ShrinkingMap: shrinkingmap.New[ElementType, int](),
}
}

// Add adds the given mutations to the elements and returns the resulting net mutations for the set that are formed by
// tracking the elements that rise above the given threshold.
func (s *setArithmetic[ElementType]) Add(mutations SetMutations[ElementType], threshold ...int) SetMutations[ElementType] {
m := NewSetMutations[ElementType]()

mutations.AddedElements().Range(s.AddedElementsCollector(m, threshold...))
mutations.DeletedElements().Range(s.SubtractedElementsCollector(m, threshold...))

return m
}

// AddedElementsCollector returns a function that adds an element to the given mutations if its occurrence count reaches
// the given threshold (after the addition).
func (s *setArithmetic[ElementType]) AddedElementsCollector(mutations SetMutations[ElementType], threshold ...int) func(ElementType) {
return s.elementsCollector(mutations.AddedElements(), mutations.DeletedElements(), true, lo.First(threshold, 1))
}

// Subtract subtracts the given mutations from the elements and returns the resulting net mutations for the set that are
// formed by tracking the elements that fall below the given threshold.
func (s *setArithmetic[ElementType]) Subtract(mutations SetMutations[ElementType], threshold ...int) SetMutations[ElementType] {
m := NewSetMutations[ElementType]()

mutations.AddedElements().Range(s.SubtractedElementsCollector(m, threshold...))
mutations.DeletedElements().Range(s.AddedElementsCollector(m, threshold...))

return m
}

// SubtractedElementsCollector returns a function that deletes an element from the given mutations if its occurrence count
// falls below the given threshold (after the subtraction)
func (s *setArithmetic[ElementType]) SubtractedElementsCollector(mutations SetMutations[ElementType], threshold ...int) func(ElementType) {
return s.elementsCollector(mutations.DeletedElements(), mutations.AddedElements(), false, lo.First(threshold, 1))
}

// elementsCollector returns a function that collects elements in the given sets that pass the given threshold in either
// direction.
func (s *setArithmetic[ElementType]) elementsCollector(targetSet, opposingSet Set[ElementType], increase bool, threshold int) func(ElementType) {
return func(element ElementType) {
if s.Compute(element, func(currentValue int, _ bool) int {
return currentValue + lo.Cond(increase, 1, -1)
}) == lo.Cond(increase, threshold, threshold-1) && !opposingSet.Delete(element) {
targetSet.Add(element)
}
}
}

// endregion ///////////////////////////////////////////////////////////////////////////////////////////////////////////

0 comments on commit 84bc933

Please sign in to comment.