diff --git a/NBitcoin/Protocol/Filters/NodeFiltersCollection.cs b/NBitcoin/Protocol/Filters/NodeFiltersCollection.cs index 5f57381fa2..7e9b978c7e 100644 --- a/NBitcoin/Protocol/Filters/NodeFiltersCollection.cs +++ b/NBitcoin/Protocol/Filters/NodeFiltersCollection.cs @@ -7,7 +7,7 @@ namespace NBitcoin.Protocol.Filters { - public class NodeFiltersCollection : ThreadSafeCollection + public class NodeFiltersCollection : ThreadSafeList { public IDisposable Add(Action onReceiving, Action onSending = null) { diff --git a/NBitcoin/Utils/ThreadSafeList.cs b/NBitcoin/Utils/ThreadSafeList.cs new file mode 100644 index 0000000000..dc891e9fd4 --- /dev/null +++ b/NBitcoin/Utils/ThreadSafeList.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace NBitcoin +{ + public class ThreadSafeList : IEnumerable + { + private List _Behaviors; + private object _lock = new object(); + + private List _EnumeratorList = null; + + public ThreadSafeList() + { + lock (_lock) + _Behaviors = new List(); + } + + /// + /// Add an item to the collection + /// + /// + /// When disposed, the item is removed + public IDisposable Add(T item) + { + if (item == null) + throw new ArgumentNullException(nameof(item)); + OnAdding(item); + lock (_lock) + { + _Behaviors.Add(item); + _EnumeratorList = null; + } + return new ActionDisposable(() => + { + }, () => Remove(item)); + } + + protected virtual void OnAdding(T obj) + { + } + protected virtual void OnRemoved(T obj) + { + } + + public bool Remove(T item) + { + bool removed = false; + lock (_lock) + { + removed = _Behaviors.Remove(item); + _EnumeratorList = null; + } + + if (removed) + OnRemoved(item); + return removed; + } + + public void Clear() + { + foreach (var behavior in this) + Remove(behavior); + } + + public T FindOrCreate() where U : T, new() + { + return FindOrCreate(() => new U()); + } + public U FindOrCreate(Func create) where U : T + { + var result = this.OfType().FirstOrDefault(); + if (result == null) + { + result = create(); + Add(result); + } + return result; + } + public U Find() where U : T + { + return this.OfType().FirstOrDefault(); + } + + public void Remove() where U : T + { + foreach (var b in this.OfType()) + { + Remove(b); + } + } + + #region IEnumerable Members + + public IEnumerator GetEnumerator() + { + IEnumerator enumerator = _EnumeratorList?.GetEnumerator(); + if (enumerator == null) + { + lock (_lock) + { + var behaviorsList = _Behaviors.ToList(); + _EnumeratorList = behaviorsList; + enumerator = behaviorsList.GetEnumerator(); + } + } + return enumerator; + } + + #endregion + + #region IEnumerable Members + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + #endregion + } +}