diff --git a/release-notes.txt b/release-notes.txt index 293fd16..fe22bfe 100644 --- a/release-notes.txt +++ b/release-notes.txt @@ -10,6 +10,7 @@ Release notes: - adds TaskSeq.unfold and TaskSeq.unfoldAsync, #289 - adds TaskSeq.chunkBySize (closes #258) and TaskSeq.windowed, #289 - fixes: CancellationToken passed to GetAsyncEnumerator is now honored in MoveNextAsync, #179 + - adds TaskSeq.withCancellation, #167 0.6.0 - fixes: async { for item in taskSeq do ... } no longer wraps exceptions in AggregateException, #129 diff --git a/src/FSharp.Control.TaskSeq.Test/FSharp.Control.TaskSeq.Test.fsproj b/src/FSharp.Control.TaskSeq.Test/FSharp.Control.TaskSeq.Test.fsproj index abc8b82..c8b5051 100644 --- a/src/FSharp.Control.TaskSeq.Test/FSharp.Control.TaskSeq.Test.fsproj +++ b/src/FSharp.Control.TaskSeq.Test/FSharp.Control.TaskSeq.Test.fsproj @@ -71,6 +71,7 @@ + diff --git a/src/FSharp.Control.TaskSeq.Test/TaskSeq.WithCancellation.Tests.fs b/src/FSharp.Control.TaskSeq.Test/TaskSeq.WithCancellation.Tests.fs new file mode 100644 index 0000000..10b31b7 --- /dev/null +++ b/src/FSharp.Control.TaskSeq.Test/TaskSeq.WithCancellation.Tests.fs @@ -0,0 +1,172 @@ +module TaskSeq.Tests.``WithCancellation`` + +open System +open System.Collections.Generic +open System.Threading +open System.Threading.Tasks + +open Xunit +open FsUnit.Xunit + +open FSharp.Control + +/// A simple IAsyncEnumerable whose GetAsyncEnumerator records the token it was called with. +type TokenCapturingSeq<'T>(items: 'T list) = + let mutable capturedToken = CancellationToken.None + + member _.CapturedToken = capturedToken + + interface IAsyncEnumerable<'T> with + member _.GetAsyncEnumerator(ct) = + capturedToken <- ct + + let source = taskSeq { + for x in items do + yield x + } + + source.GetAsyncEnumerator(ct) + +module ``Null check`` = + + [] + let ``TaskSeq-withCancellation: null source throws ArgumentNullException`` () = + assertNullArg + <| fun () -> TaskSeq.withCancellation CancellationToken.None null + +module ``Token threading`` = + + [] + let ``TaskSeq-withCancellation: passes supplied token to GetAsyncEnumerator`` () = task { + let source = TokenCapturingSeq([ 1; 2; 3 ]) + use cts = new CancellationTokenSource() + + let wrapped = TaskSeq.withCancellation cts.Token (source :> IAsyncEnumerable<_>) + let! _ = TaskSeq.toArrayAsync wrapped + source.CapturedToken |> should equal cts.Token + } + + [] + let ``TaskSeq-withCancellation: overrides any token passed to GetAsyncEnumerator`` () = task { + let source = TokenCapturingSeq([ 1; 2; 3 ]) + use cts = new CancellationTokenSource() + + let wrapped = TaskSeq.withCancellation cts.Token (source :> IAsyncEnumerable<_>) + + // Consume with a different token; withCancellation should win + use outerCts = new CancellationTokenSource() + let enum = wrapped.GetAsyncEnumerator(outerCts.Token) + + while! enum.MoveNextAsync() do + () + + source.CapturedToken |> should equal cts.Token + } + + [] + let ``TaskSeq-withCancellation: CancellationToken.None passes through correctly`` () = task { + let source = TokenCapturingSeq([ 10; 20 ]) + + let wrapped = TaskSeq.withCancellation CancellationToken.None (source :> IAsyncEnumerable<_>) + let! _ = TaskSeq.toArrayAsync wrapped + source.CapturedToken |> should equal CancellationToken.None + } + +module ``Cancellation behaviour`` = + + [] + let ``TaskSeq-withCancellation: pre-cancelled token causes OperationCanceledException on iteration`` () = task { + use cts = new CancellationTokenSource() + cts.Cancel() + + let source = taskSeq { + while true do + yield 1 + } + + let wrapped = TaskSeq.withCancellation cts.Token source + + fun () -> TaskSeq.iter ignore wrapped |> Task.ignore + |> should throwAsync typeof + } + + [] + let ``TaskSeq-withCancellation: token cancelled mid-iteration raises OperationCanceledException`` () = task { + use cts = new CancellationTokenSource() + + let source = taskSeq { + for i in 1..100 do + yield i + } + + let wrapped = TaskSeq.withCancellation cts.Token source + + fun () -> + task { + let mutable count = 0 + use enum = wrapped.GetAsyncEnumerator(CancellationToken.None) + + while! enum.MoveNextAsync() do + count <- count + 1 + + if count = 3 then + cts.Cancel() + } + |> Task.ignore + |> should throwAsync typeof + } + +module ``Sequence contents`` = + + [] + let ``TaskSeq-withCancellation: empty source produces empty sequence`` () = + TaskSeq.empty + |> TaskSeq.withCancellation CancellationToken.None + |> verifyEmpty + + [] + let ``TaskSeq-withCancellation: finite source produces all items`` () = task { + let! result = + taskSeq { + for i in 1..10 do + yield i + } + |> TaskSeq.withCancellation CancellationToken.None + |> TaskSeq.toArrayAsync + + result |> should equal [| 1..10 |] + } + + [] + let ``TaskSeq-withCancellation: can be used with TaskSeq combinators`` () = task { + use cts = new CancellationTokenSource() + + let! result = + taskSeq { + for i in 1..5 do + yield i + } + |> TaskSeq.withCancellation cts.Token + |> TaskSeq.map (fun x -> x * 2) + |> TaskSeq.toArrayAsync + + result |> should equal [| 2; 4; 6; 8; 10 |] + } + + [] + let ``TaskSeq-withCancellation: can be piped like .WithCancellation usage pattern`` () = task { + use cts = new CancellationTokenSource() + let mutable collected = ResizeArray() + + let source = taskSeq { + for i in 1..5 do + yield i + } + + do! + source + |> TaskSeq.withCancellation cts.Token + |> TaskSeq.iterAsync (fun x -> task { collected.Add(x) }) + + collected |> Seq.toArray |> should equal [| 1..5 |] + } diff --git a/src/FSharp.Control.TaskSeq/TaskSeq.fs b/src/FSharp.Control.TaskSeq/TaskSeq.fs index 046cca2..28e1f13 100644 --- a/src/FSharp.Control.TaskSeq/TaskSeq.fs +++ b/src/FSharp.Control.TaskSeq/TaskSeq.fs @@ -258,6 +258,13 @@ type TaskSeq private () = yield c } + static member withCancellation (cancellationToken: CancellationToken) (source: TaskSeq<'T>) = + Internal.checkNonNull (nameof source) source + + { new IAsyncEnumerable<'T> with + member _.GetAsyncEnumerator(_ct) = source.GetAsyncEnumerator(cancellationToken) + } + // // Utility functions // diff --git a/src/FSharp.Control.TaskSeq/TaskSeq.fsi b/src/FSharp.Control.TaskSeq/TaskSeq.fsi index cd47d20..89fd22f 100644 --- a/src/FSharp.Control.TaskSeq/TaskSeq.fsi +++ b/src/FSharp.Control.TaskSeq/TaskSeq.fsi @@ -1,6 +1,7 @@ namespace FSharp.Control open System.Collections.Generic +open System.Threading open System.Threading.Tasks [] @@ -602,6 +603,23 @@ type TaskSeq = /// Thrown when the input sequence is null. static member ofAsyncArray: source: Async<'T> array -> TaskSeq<'T> + /// + /// Returns a task sequence that, when iterated, passes the given to the + /// underlying . This is the equivalent of calling + /// .WithCancellation(cancellationToken) on an . + /// + /// + /// The supplied to this function overrides any token that would otherwise + /// be passed to the enumerator. This is useful when consuming sequences from libraries such as Entity Framework, + /// which accept a through GetAsyncEnumerator. + /// + /// + /// The cancellation token to pass to GetAsyncEnumerator. + /// The input task sequence. + /// A task sequence that uses the given when iterated. + /// Thrown when the input task sequence is null. + static member withCancellation: cancellationToken: CancellationToken -> source: TaskSeq<'T> -> TaskSeq<'T> + /// /// Views each item in the input task sequence as , boxing value types. ///