Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 100 additions & 71 deletions src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,14 @@ public static void EfficientCopyTo(this Stream input, Stream output)

public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
{
try
{
using var manualResetEvent = new ManualResetEventSlim();
var readOperation = stream.BeginRead(
buffer,
offset,
count,
state => ((ManualResetEventSlim)state.AsyncState).Set(),
manualResetEvent);

if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
{
return stream.EndRead(readOperation);
}
}
catch (OperationCanceledException)
{
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
}
catch (ObjectDisposedException)
{
throw new IOException();
}

try
{
stream.Dispose();
}
catch
{
// Ignore any exceptions
}

cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
return ExecuteOperationWithTimeout(
stream,
(str, state) => str.Read(state.Buffer, state.Offset, state.Count),
buffer,
offset,
count,
timeout,
cancellationToken);
}

public static async Task<int> ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
Expand Down Expand Up @@ -219,43 +192,18 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination,

public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
{
try
{
using var manualResetEvent = new ManualResetEventSlim();
var writeOperation = stream.BeginWrite(
buffer,
offset,
count,
state => ((ManualResetEventSlim)state.AsyncState).Set(),
manualResetEvent);

if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
ExecuteOperationWithTimeout(
stream,
(str, state) =>
{
stream.EndWrite(writeOperation);
return;
}
}
catch (OperationCanceledException)
{
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
}
catch (ObjectDisposedException)
{
// It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true.
throw new IOException();
}

try
{
stream.Dispose();
}
catch
{
// Ignore any exceptions
}

cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
str.Write(state.Buffer, state.Offset, state.Count);
return true;
},
buffer,
offset,
count,
timeout,
cancellationToken);
}

public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
Expand Down Expand Up @@ -325,5 +273,86 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op
count -= bytesToWrite;
}
}

private static TResult ExecuteOperationWithTimeout<TResult>(Stream stream, Func<Stream, (byte[] Buffer, int Offset, int Count), TResult> operation, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
{
StreamDisposeCallbackState callbackState = null;
Timer timer = null;
CancellationTokenRegistration cancellationSubscription = default;
if (timeout != Timeout.InfiniteTimeSpan)
{
callbackState = new StreamDisposeCallbackState(stream);
timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan);
}

if (cancellationToken.CanBeCanceled)
{
callbackState ??= new StreamDisposeCallbackState(stream);
cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState);
}

try
{
var result = operation(stream, (buffer, offset, count));
if (callbackState?.TryChangeStateFromInProgress(OperationState.Done) == false)
{
// if cannot change the state - then the stream was/will be disposed, throw here
throw new IOException();
}

return result;
}
catch (IOException)
{
if (callbackState?.OperationState == OperationState.Cancelled)
{
cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException();
}

throw;
}
finally
{
timer?.Dispose();
cancellationSubscription.Dispose();
}

static void DisposeStreamCallback(object state)
{
var disposeCallbackState = (StreamDisposeCallbackState)state;
if (!disposeCallbackState.TryChangeStateFromInProgress(OperationState.Cancelled))
{
// if cannot change the state - then I/O was already succeeded
return;
}

try
{
disposeCallbackState.Stream.Dispose();
}
catch (Exception)
{
// callbacks should not fail, suppress any exceptions here
}
}
}

private record StreamDisposeCallbackState(Stream Stream)
{
private int _operationState = 0;

public OperationState OperationState => (OperationState)_operationState;

public bool TryChangeStateFromInProgress(OperationState newState) =>
Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress;
}

private enum OperationState
{
InProgress = 0,
Done,
Cancelled,
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -811,19 +811,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the

private void SetupStreamRead(Mock<Stream> streamMock, TaskCompletionSource<int> tcs)
{
streamMock.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) =>
{
var innerTcs = new TaskCompletionSource<int>(state);
tcs.Task.ContinueWith(t =>
{
innerTcs.TrySetException(t.Exception.InnerException);
callback(innerTcs.Task);
});
return innerTcs.Task;
});
streamMock.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
streamMock.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult());
streamMock.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns(tcs.Task);
streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part
var bytes = new byte[] { 1, 2, 3 };
var n = 0;
var position = 0;
Task<int> ReadPartial (byte[] buffer, int offset, int count)
int ReadPartial (byte[] buffer, int offset, int count)
{
var length = partition[n++];
Buffer.BlockCopy(bytes, position, buffer, offset, length);
position += length;
return Task.FromResult(length);
return length;
}

mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
var destination = new byte[3];

if (async)
Expand Down Expand Up @@ -267,20 +265,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par
var destination = new ByteArrayBuffer(new byte[3], 3);
var n = 0;
var position = 0;
Task<int> ReadPartial (byte[] buffer, int offset, int count)
int ReadPartial (byte[] buffer, int offset, int count)
{
var length = partition[n++];
Buffer.BlockCopy(bytes, position, buffer, offset, length);
position += length;
return Task.FromResult(length);
return length;
}

mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));

if (async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open()

var mockStream = new Mock<Stream>();
mockStream
.Setup(s => s.BeginWrite(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
.Setup(s => s.Write(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
.Callback(() => EnqueueEvent(HelloReceivedEvent))
.Throws(new Exception("Stream is closed."));

Expand Down