Skip to content

Commit

Permalink
Add UniTaskAsyncEnumerable.Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
hadashiA committed Sep 8, 2023
1 parent b5779d2 commit e32a5b5
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 44 deletions.
126 changes: 121 additions & 5 deletions src/UniTask.NetCoreTests/Linq/Merge.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,137 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Cysharp.Threading.Tasks;
using Cysharp.Threading.Tasks.Linq;
using FluentAssertions;
using Xunit;

namespace NetCoreTests.Linq
{
public class Merge
public class MergeTest
{
[Fact]
public async Task Hoge()
public async Task TwoSource()
{
var a = UniTaskAsyncEnumerable.Range(0, 5).Select(x => (x + 1) * 100);
var b = UniTaskAsyncEnumerable.Range(0, 3).Select(x => (x + 1) * 200);
var semaphore = new SemaphoreSlim(1, 1);

var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("A1");
semaphore.Release();
await semaphore.WaitAsync();
await writer.YieldAsync("A2");
semaphore.Release();
});

var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("B1");
await writer.YieldAsync("B2");
semaphore.Release();
await semaphore.WaitAsync();
await writer.YieldAsync("B3");
semaphore.Release();
});

var result = await a.Merge(b).ToArrayAsync();
result.Should().Equal("A1", "B1", "B2", "A2", "B3");
}

[Fact]
public async Task ThreeSource()
{
var semaphore = new SemaphoreSlim(0, 1);

var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("A1");
semaphore.Release();
await semaphore.WaitAsync();
await writer.YieldAsync("A2");
semaphore.Release();
});

var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("B1");
await writer.YieldAsync("B2");
semaphore.Release();
await semaphore.WaitAsync();
await writer.YieldAsync("B3");
semaphore.Release();
});

var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await UniTask.SwitchToThreadPool();
await writer.YieldAsync("C1");
semaphore.Release();
});

var result = await a.Merge(b, c).ToArrayAsync();
result.Should().Equal("C1", "A1", "B1", "B2", "A2", "B3");
}

[Fact]
public async Task Throw()
{
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await writer.YieldAsync("A1");
});

var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
throw new UniTaskTestException();
});

var enumerator = a.Merge(b).GetAsyncEnumerator();
(await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be(100);
enumerator.Current.Should().Be("A1");

await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync());
}

[Fact]
public async Task Cancel()
{
var cts = new CancellationTokenSource();

var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await writer.YieldAsync("A1");
});

var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{
await writer.YieldAsync("B1");
});

var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token);
(await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be("A1");

cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
}
}
}
115 changes: 76 additions & 39 deletions src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Cysharp.Threading.Tasks.Internal;

Expand All @@ -12,7 +13,21 @@ public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));

return new Merge<T>(first, second);
return new Merge<T>(new [] { first, second });
}

public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
{
Error.ThrowArgumentNullException(first, nameof(first));
Error.ThrowArgumentNullException(second, nameof(second));
Error.ThrowArgumentNullException(third, nameof(third));

return new Merge<T>(new[] { first, second, third });
}

public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
{
return new Merge<T>(sources.ToArray());
}

public static IUniTaskAsyncEnumerable<T> Merge<T>(params IUniTaskAsyncEnumerable<T>[] sources)
Expand All @@ -25,7 +40,7 @@ internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
{
readonly IUniTaskAsyncEnumerable<T>[] sources;

public Merge(params IUniTaskAsyncEnumerable<T>[] sources)
public Merge(IUniTaskAsyncEnumerable<T>[] sources)
{
if (sources.Length <= 0)
{
Expand All @@ -51,7 +66,7 @@ sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator<T>
readonly int length;
readonly IUniTaskAsyncEnumerator<T>[] enumerators;
readonly MergeSourceState[] states;
readonly Queue<T> resultQueue = new Queue<T>();
readonly Queue<(T, Exception)> resultQueue = new Queue<(T, Exception)>();
readonly CancellationToken cancellationToken;

public T Current { get; private set; }
Expand All @@ -64,7 +79,7 @@ public _Merge(IUniTaskAsyncEnumerable<T>[] sources, CancellationToken cancellati
enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
for (var i = 0; i < length; i++)
{
enumerators[i] = sources[i].GetAsyncEnumerator();
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
states[i] = MergeSourceState.Pending;
}
}
Expand All @@ -74,25 +89,29 @@ public UniTask<bool> MoveNextAsync()
cancellationToken.ThrowIfCancellationRequested();
completionSource.Reset();

if (TryDequeue(out var queuedValue))
if (TryDequeue(out var queuedValue, out var queuedException))
{
Current = queuedValue;
completionSource.TrySetResult(!HasIncompleteSource());
if (queuedException != null)
{
completionSource.TrySetException(queuedException);
}
else
{
Current = queuedValue;
completionSource.TrySetResult(!IsCompletedAll());
}
return new UniTask<bool>(this, completionSource.Version);
}

for (var i = 0; i < length; i++)
{
lock (states)
{
if (states[i] == MergeSourceState.Pending)
{
states[i] = MergeSourceState.Running;
}
else
if (states[i] != MergeSourceState.Pending)
{
continue;
}
states[i] = MergeSourceState.Running;
}

var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
Expand All @@ -102,24 +121,21 @@ public UniTask<bool> MoveNextAsync()
}
else
{
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i));
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
}
}
return new UniTask<bool>(this, completionSource.Version);
}

public async UniTask DisposeAsync()
{
foreach (var enumerator in enumerators)
for (var i = 0; i < length; i++)
{
await enumerator.DisposeAsync();
await enumerators[i].DisposeAsync();
}

lock (states)
{
ArrayPool<MergeSourceState>.Shared.Return(states);
}
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators);
ArrayPool<MergeSourceState>.Shared.Return(states, true);
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);
}

static void GetResultAt(object state)
Expand All @@ -130,55 +146,76 @@ static void GetResultAt(object state)

void GetResultAt(int index, UniTask<bool>.Awaiter awaiter)
{
if (!TryGetResult(awaiter, out var hasNext))
bool hasNext;
try
{
return;
hasNext = awaiter.GetResult();
lock (states)
{
states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
}
}

lock (states)
catch (Exception ex)
{
states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
if (!completionSource.TrySetException(ex))
{
lock (resultQueue)
{
resultQueue.Enqueue((default, ex));
}
}
return;
}

var recentValue = enumerators[index].Current;
if (completionSource.TrySetResult(HasIncompleteSource()))
var completed = IsCompletedAll();
if (hasNext || completed)
{
Current = recentValue;
}
else
{
lock (resultQueue)
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
{
lock (resultQueue)
{
resultQueue.Enqueue((enumerators[index].Current, null));
}
}
else
{
resultQueue.Enqueue(recentValue);
Current = enumerators[index].Current;
completionSource.TrySetResult(!completed);
}
}
}

bool TryDequeue(out T value)
bool TryDequeue(out T value, out Exception ex)
{
lock (resultQueue)
{
if (resultQueue.Count > 0)
{
value = resultQueue.Dequeue();
var result = resultQueue.Dequeue();
value = result.Item1;
ex = result.Item2;
return true;
}
}
value = default;
ex = default;
return false;
}

bool HasIncompleteSource()
bool IsCompletedAll()
{
lock (states)
{
for (var i = 0; i < length; i++)
{
if (states[i] != MergeSourceState.Completed) return true;
if (states[i] != MergeSourceState.Completed)
{
return false;
}
}
return false;
return true;
}
}
}
}
}
}

0 comments on commit e32a5b5

Please sign in to comment.