diff --git a/src/Nito.AsyncEx.Tasks/Interop/ApmAsyncFactory.cs b/src/Nito.AsyncEx.Tasks/Interop/ApmAsyncFactory.cs index e694b90..59a834d 100644 --- a/src/Nito.AsyncEx.Tasks/Interop/ApmAsyncFactory.cs +++ b/src/Nito.AsyncEx.Tasks/Interop/ApmAsyncFactory.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using Nito.AsyncEx.Synchronous; @@ -18,6 +19,18 @@ public static class ApmAsyncFactory /// The asynchronous operation, to be returned by the Begin method of the APM pattern. public static IAsyncResult ToBegin(Task task, AsyncCallback callback, object state) { + if (task == null) + { + throw new ArgumentNullException(nameof(task)); + } + + if (task.IsCompleted) + { + // we need this so it throws in case of faulted task + task.GetAwaiter().GetResult(); + return new CompletedAsyncResult(state); + } + var tcs = new TaskCompletionSource(state, TaskCreationOptions.RunContinuationsAsynchronously); SynchronizationContextSwitcher.NoContext(() => CompleteAsync(task, callback, tcs)); return tcs.Task; @@ -54,7 +67,18 @@ private static async void CompleteAsync(Task task, AsyncCallback callback, TaskC /// The result of the asynchronous operation, to be returned by the End method of the APM pattern. public static void ToEnd(IAsyncResult asyncResult) { - ((Task)asyncResult).WaitAndUnwrapException(); + if (asyncResult is Task task) + { + task.GetAwaiter().GetResult(); + } + else if (asyncResult is CompletedAsyncResult) + { + // Do nothing + } + else + { + throw new ArgumentException("Invalid asyncResult", nameof(asyncResult)); + } } /// @@ -66,6 +90,16 @@ public static void ToEnd(IAsyncResult asyncResult) /// The asynchronous operation, to be returned by the Begin method of the APM pattern. public static IAsyncResult ToBegin(Task task, AsyncCallback callback, object state) { + if (task == null) + { + throw new ArgumentNullException(nameof(task)); + } + + if (task.IsCompleted) + { + return new CompletedAsyncResult(task.GetAwaiter().GetResult(), state); + } + var tcs = new TaskCompletionSource(state, TaskCreationOptions.RunContinuationsAsynchronously); SynchronizationContextSwitcher.NoContext(() => CompleteAsync(task, callback, tcs)); return tcs.Task; @@ -101,7 +135,36 @@ private static async void CompleteAsync(Task task, AsyncCallba /// The result of the asynchronous operation, to be returned by the End method of the APM pattern. public static TResult ToEnd(IAsyncResult asyncResult) { - return ((Task)asyncResult).WaitAndUnwrapException(); + return asyncResult switch + { + Task task => task.GetAwaiter().GetResult(), + CompletedAsyncResult completedAsyncResult => completedAsyncResult.Result, + _ => throw new ArgumentException("Invalid asyncResult", nameof(asyncResult)) + }; + } + + internal class CompletedAsyncResult : CompletedAsyncResult + { + public CompletedAsyncResult(T result, object? state = null) + : base(state) + { + Result = result; + } + + public T Result { get; } + } + + internal class CompletedAsyncResult : IAsyncResult + { + public CompletedAsyncResult(object? state = null) + { + AsyncState = state; + } + + public bool IsCompleted { get; } = true; + public bool CompletedSynchronously { get; } = true; + public WaitHandle? AsyncWaitHandle { get; } + public object? AsyncState { get; } } } }