From 40bd270c71725a6be80cf35aa431ca3f540a4b9b Mon Sep 17 00:00:00 2001 From: Ahmet Ibrahim Aksoy Date: Tue, 5 Mar 2024 14:38:45 +0300 Subject: [PATCH] Implement HttpWebRequest AllowWriteStreamBuffering property (#95001) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement HttpWebRequest WriteStreamBuffering * Fix tests * Fix tests * Review feedback * Fix timeouts * fix test build * Hang test * Fix build * fix tests * delete unnecessary line * fix tests * refactor test * Dispose http client * Fix buffer write in RequestStream.Flush() method * Fix same length bug in FlushAsync * Update ContentLength in HttpWebRequestTest.cs * Fix flushing and ending request stream * Fix flushing and ending request stream * Fix FlushAsync method to handle cancellation * Update src/libraries/System.Net.Requests/src/System/Net/HttpClientContentStream.cs Co-authored-by: Anton Firszov * Review feedback * Bound streamBuffer lifecycle to HttpClientContentStream * Review feedback * Review feedback * Change ??= to = * change delay on test * Apply suggestions from code review Co-authored-by: Miha Zupan * Fix build * Review feedback * Apply suggestions from code review Co-authored-by: Miha Zupan * Review feedback * Apply suggestions from code review Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> * Add ProtocolViolationException if we're not buffering and we didn't set either SendChunked or ContentLength * Review feedback * Add test for not buffering and sending the content before we call `GetResponse` * Update src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> * Remove exception * Review feedback * Fix string resources * Add RequestStreamContent class for handling request stream serialization * Seperate Buffering and Non-Buffering Streams and Connect non-buffering stream to HttpClient Internal Stream directly * Remove unnecessary code in HttpWebRequestTest * Use random data for testing * Update src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs Co-authored-by: Anton Firszov * Remove default flushes and add flush on test * Review feedback * Removing unused code --------- Co-authored-by: Anton Firszov Co-authored-by: Miha Zupan Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> --- .../src/System.Net.Requests.csproj | 2 + .../src/System/Net/HttpWebRequest.cs | 177 +++++++++++------- .../src/System/Net/RequestBufferingStream.cs | 134 +++++++++++++ .../src/System/Net/RequestStream.cs | 70 ++++--- .../src/System/Net/RequestStreamContent.cs | 31 +++ .../tests/HttpWebRequestTest.cs | 155 +++++++++++++-- 6 files changed, 466 insertions(+), 103 deletions(-) create mode 100644 src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs create mode 100644 src/libraries/System.Net.Requests/src/System/Net/RequestStreamContent.cs diff --git a/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj b/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj index b53b7272ea417..46bda299d9a64 100644 --- a/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj +++ b/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj @@ -29,6 +29,7 @@ + @@ -48,6 +49,7 @@ + diff --git a/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs b/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs index 44c6011591365..8f37002cf206b 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -12,6 +13,7 @@ using System.Net.Http.Headers; using System.Net.Security; using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Runtime.Serialization; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -40,6 +42,7 @@ public class HttpWebRequest : WebRequest, ISerializable private IWebProxy? _proxy = WebRequest.DefaultWebProxy; private Task? _sendRequestTask; + private HttpRequestMessage? _sendRequestMessage; private static int _defaultMaxResponseHeadersLength = HttpHandlerDefaults.DefaultMaxResponseHeadersLength; private static int _defaultMaximumErrorResponseLength = -1; @@ -62,7 +65,7 @@ public class HttpWebRequest : WebRequest, ISerializable private bool _hostHasPort; private Uri? _hostUri; - private RequestStream? _requestStream; + private Stream? _requestStream; private TaskCompletionSource? _requestStreamOperation; private TaskCompletionSource? _responseOperation; private AsyncCallback? _requestStreamCallback; @@ -78,6 +81,8 @@ public class HttpWebRequest : WebRequest, ISerializable private static readonly object s_syncRoot = new object(); private static volatile HttpClient? s_cachedHttpClient; private static HttpClientParameters? s_cachedHttpClientParameters; + private bool _disposeRequired; + private HttpClient? _httpClient; //these should be safe. [Flags] @@ -1003,17 +1008,17 @@ public override void Abort() { _responseCallback(_responseOperation.Task); } - - // Cancel the underlying send operation. - Debug.Assert(_sendRequestCts != null); - _sendRequestCts.Cancel(); } - else if (_requestStreamOperation != null) + if (_requestStreamOperation != null) { if (_requestStreamOperation.TrySetCanceled() && _requestStreamCallback != null) { _requestStreamCallback(_requestStreamOperation.Task); } + + // Cancel the underlying send operation. + Debug.Assert(_sendRequestCts != null); + _sendRequestCts.Cancel(); } } @@ -1041,8 +1046,7 @@ public override WebResponse GetResponse() { try { - _sendRequestCts = new CancellationTokenSource(); - return SendRequest(async: false).GetAwaiter().GetResult(); + return HandleResponse(async: false).GetAwaiter().GetResult(); } catch (Exception ex) { @@ -1052,10 +1056,11 @@ public override WebResponse GetResponse() public override Stream GetRequestStream() { + CheckRequestStream(); return InternalGetRequestStream().Result; } - private Task InternalGetRequestStream() + private void CheckRequestStream() { CheckAbort(); @@ -1073,10 +1078,28 @@ private Task InternalGetRequestStream() { throw new InvalidOperationException(SR.net_reqsubmitted); } + } - _requestStream = new RequestStream(); + private async Task InternalGetRequestStream() + { + // If we aren't buffering we need to open the connection right away. + // Because we need to send the data as soon as possible when it's available from the RequestStream. + // Making this allows us to keep the sync send request path for buffering cases. + if (AllowWriteStreamBuffering is false) + { + // We're calling SendRequest with async, because we need to open the connection and send the request + // Otherwise, sync path will block the current thread until the request is sent. + TaskCompletionSource getStreamTcs = new(); + TaskCompletionSource completeTcs = new(); + _sendRequestTask = SendRequest(async: true, new RequestStreamContent(getStreamTcs, completeTcs)); + _requestStream = new RequestStream(await getStreamTcs.Task.ConfigureAwait(false), completeTcs); + } + else + { + _requestStream = new RequestBufferingStream(); + } - return Task.FromResult((Stream)_requestStream); + return _requestStream; } public Stream EndGetRequestStream(IAsyncResult asyncResult, out TransportContext? context) @@ -1100,6 +1123,8 @@ public override IAsyncResult BeginGetRequestStream(AsyncCallback? callback, obje throw new InvalidOperationException(SR.net_repcall); } + CheckRequestStream(); + _requestStreamCallback = callback; _requestStreamOperation = InternalGetRequestStream().ToApm(callback, state); @@ -1133,78 +1158,95 @@ public override Stream EndGetRequestStream(IAsyncResult asyncResult) return stream; } - private async Task SendRequest(bool async) + private Task SendRequest(bool async, HttpContent? content = null) { if (RequestSubmitted) { throw new InvalidOperationException(SR.net_reqsubmitted); } - var request = new HttpRequestMessage(HttpMethod.Parse(_originVerb), _requestUri); + _sendRequestMessage = new HttpRequestMessage(HttpMethod.Parse(_originVerb), _requestUri); + _sendRequestCts = new CancellationTokenSource(); + _httpClient = GetCachedOrCreateHttpClient(async, out _disposeRequired); - bool disposeRequired = false; - HttpClient? client = null; - try + if (content is not null) + { + _sendRequestMessage.Content = content; + } + + if (_hostUri is not null) { - client = GetCachedOrCreateHttpClient(async, out disposeRequired); - if (_requestStream != null) + _sendRequestMessage.Headers.Host = Host; + } + + AddCacheControlHeaders(_sendRequestMessage); + + // Copy the HttpWebRequest request headers from the WebHeaderCollection into HttpRequestMessage.Headers and + // HttpRequestMessage.Content.Headers. + foreach (string headerName in _webHeaderCollection) + { + // The System.Net.Http APIs require HttpRequestMessage headers to be properly divided between the request headers + // collection and the request content headers collection for all well-known header names. And custom headers + // are only allowed in the request headers collection and not in the request content headers collection. + if (IsWellKnownContentHeader(headerName)) { - ArraySegment bytes = _requestStream.GetBuffer(); - request.Content = new ByteArrayContent(bytes.Array!, bytes.Offset, bytes.Count); + _sendRequestMessage.Content ??= new ByteArrayContent(Array.Empty()); + _sendRequestMessage.Content.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]); } - - if (_hostUri != null) + else { - request.Headers.Host = Host; + _sendRequestMessage.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]); } + } - AddCacheControlHeaders(request); + if (_servicePoint?.Expect100Continue == true) + { + _sendRequestMessage.Headers.ExpectContinue = true; + } - // Copy the HttpWebRequest request headers from the WebHeaderCollection into HttpRequestMessage.Headers and - // HttpRequestMessage.Content.Headers. - foreach (string headerName in _webHeaderCollection) - { - // The System.Net.Http APIs require HttpRequestMessage headers to be properly divided between the request headers - // collection and the request content headers collection for all well-known header names. And custom headers - // are only allowed in the request headers collection and not in the request content headers collection. - if (IsWellKnownContentHeader(headerName)) - { - // Create empty content so that we can send the entity-body header. - request.Content ??= new ByteArrayContent(Array.Empty()); + _sendRequestMessage.Headers.TransferEncodingChunked = SendChunked; - request.Content.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]); - } - else - { - request.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]); - } - } + if (KeepAlive) + { + _sendRequestMessage.Headers.Connection.Add(HttpKnownHeaderNames.KeepAlive); + } + else + { + _sendRequestMessage.Headers.ConnectionClose = true; + } - request.Headers.TransferEncodingChunked = SendChunked; + _sendRequestMessage.Version = ProtocolVersion; + HttpCompletionOption completionOption = _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead; + // If we're not buffering, there is no way to open the connection and not send the request without async. + // So we should use Async, if we're not buffering. + _sendRequestTask = async || !AllowWriteStreamBuffering ? + _httpClient.SendAsync(_sendRequestMessage, completionOption, _sendRequestCts.Token) : + Task.FromResult(_httpClient.Send(_sendRequestMessage, completionOption, _sendRequestCts.Token)); - if (KeepAlive) - { - request.Headers.Connection.Add(HttpKnownHeaderNames.KeepAlive); - } - else - { - request.Headers.ConnectionClose = true; - } + return _sendRequestTask!; + } - if (_servicePoint?.Expect100Continue == true) - { - request.Headers.ExpectContinue = true; - } + private async Task HandleResponse(bool async) + { + // If user code used requestStream and didn't dispose it + // We're completing it here. + if (_requestStream is RequestStream requestStream) + { + requestStream.Complete(); + } - request.Version = ProtocolVersion; + if (_sendRequestTask is null && _requestStream is RequestBufferingStream requestBufferingStream) + { + ArraySegment buffer = requestBufferingStream.GetBuffer(); + _sendRequestTask = SendRequest(async, new ByteArrayContent(buffer.Array!, buffer.Offset, buffer.Count)); + } - _sendRequestTask = async ? - client.SendAsync(request, _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead, _sendRequestCts!.Token) : - Task.FromResult(client.Send(request, _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead, _sendRequestCts!.Token)); + _sendRequestTask ??= SendRequest(async); + try + { HttpResponseMessage responseMessage = await _sendRequestTask.ConfigureAwait(false); - - HttpWebResponse response = new HttpWebResponse(responseMessage, _requestUri, _cookieContainer); + HttpWebResponse response = new(responseMessage, _requestUri, _cookieContainer); int maxSuccessStatusCode = AllowAutoRedirect ? 299 : 399; if ((int)response.StatusCode > maxSuccessStatusCode || (int)response.StatusCode < 200) @@ -1220,9 +1262,15 @@ private async Task SendRequest(bool async) } finally { - if (disposeRequired) + _sendRequestMessage?.Dispose(); + if (_requestStream is RequestBufferingStream bufferStream) { - client?.Dispose(); + bufferStream.GetMemoryStream().Dispose(); + } + + if (_disposeRequired) + { + _httpClient?.Dispose(); } } } @@ -1348,9 +1396,8 @@ public override IAsyncResult BeginGetResponse(AsyncCallback? callback, object? s throw new InvalidOperationException(SR.net_repcall); } - _sendRequestCts = new CancellationTokenSource(); _responseCallback = callback; - _responseOperation = SendRequest(async: true).ToApm(callback, state); + _responseOperation = HandleResponse(async: true).ToApm(callback, state); return _responseOperation.Task; } diff --git a/src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs b/src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs new file mode 100644 index 0000000000000..3a5bb170314e1 --- /dev/null +++ b/src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net +{ + // Cache the request stream into a MemoryStream. + internal sealed class RequestBufferingStream : Stream + { + private bool _disposed; + private readonly MemoryStream _buffer = new MemoryStream(); + + public RequestBufferingStream() + { + } + + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public override void Flush() => ThrowIfDisposed(); // Nothing to do. + + public override Task FlushAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); + // Nothing to do. + return cancellationToken.IsCancellationRequested ? + Task.FromCanceled(cancellationToken) : + Task.CompletedTask; + } + + public override long Length + { + get + { + throw new NotSupportedException(); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(); + } + set + { + throw new NotSupportedException(); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, count); + _buffer.Write(buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, count); + return _buffer.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + return _buffer.WriteAsync(buffer, cancellationToken); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) + { + ThrowIfDisposed(); + ValidateBufferArguments(buffer, offset, count); + return _buffer.BeginWrite(buffer, offset, count, asyncCallback, asyncState); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + ThrowIfDisposed(); + _buffer.EndWrite(asyncResult); + } + + public ArraySegment GetBuffer() + { + ArraySegment bytes; + + bool success = _buffer.TryGetBuffer(out bytes); + Debug.Assert(success); // Buffer should always be visible since default MemoryStream constructor was used. + + return bytes; + } + + // We need this to dispose the MemoryStream. + public MemoryStream GetMemoryStream() + { + return _buffer; + } + + protected override void Dispose(bool disposing) + { + if (disposing && !_disposed) + { + _disposed = true; + } + base.Dispose(disposing); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed, this); + } + } +} diff --git a/src/libraries/System.Net.Requests/src/System/Net/RequestStream.cs b/src/libraries/System.Net.Requests/src/System/Net/RequestStream.cs index 5323c2ac836f0..5961339576d30 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/RequestStream.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/RequestStream.cs @@ -3,22 +3,22 @@ using System.Diagnostics; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace System.Net { - // Cache the request stream into a MemoryStream. This is the - // default behavior of Desktop HttpWebRequest.AllowWriteStreamBuffering (true). - // Unfortunately, this property is not exposed in .NET Core, so it can't be changed - // This will result in inefficient memory usage when sending (POST'ing) large - // amounts of data to the server such as from a file stream. internal sealed class RequestStream : Stream { - private readonly MemoryStream _buffer = new MemoryStream(); + private bool _disposed; + private readonly TaskCompletionSource _completeTcs; + private readonly Stream _internalStream; - public RequestStream() + public RequestStream(Stream internalStream, TaskCompletionSource completeTcs) { + _internalStream = internalStream; + _completeTcs = completeTcs; } public override bool CanRead @@ -47,15 +47,14 @@ public override bool CanWrite public override void Flush() { - // Nothing to do. + ThrowIfDisposed(); + _internalStream.Flush(); } public override Task FlushAsync(CancellationToken cancellationToken) { - // Nothing to do. - return cancellationToken.IsCancellationRequested ? - Task.FromCanceled(cancellationToken) : - Task.CompletedTask; + ThrowIfDisposed(); + return _internalStream.FlushAsync(cancellationToken); } public override long Length @@ -95,40 +94,67 @@ public override void SetLength(long value) public override void Write(byte[] buffer, int offset, int count) { + ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, count); - _buffer.Write(buffer, offset, count); + _internalStream.Write(new(buffer, offset, count)); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, count); - return _buffer.WriteAsync(buffer, offset, count, cancellationToken); + return _internalStream.WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - return _buffer.WriteAsync(buffer, cancellationToken); + ThrowIfDisposed(); + return _internalStream.WriteAsync(buffer, cancellationToken); } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) { + ThrowIfDisposed(); ValidateBufferArguments(buffer, offset, count); - return _buffer.BeginWrite(buffer, offset, count, asyncCallback, asyncState); + return _internalStream.BeginWrite(buffer, offset, count, asyncCallback, asyncState); + } + + public void Complete() + { + _completeTcs.TrySetResult(); } public override void EndWrite(IAsyncResult asyncResult) { - _buffer.EndWrite(asyncResult); + ThrowIfDisposed(); + _internalStream.EndWrite(asyncResult); } - public ArraySegment GetBuffer() + protected override void Dispose(bool disposing) { - ArraySegment bytes; + if (disposing && !_disposed) + { + _disposed = true; + } + _internalStream.Flush(); + Complete(); + base.Dispose(disposing); + } - bool success = _buffer.TryGetBuffer(out bytes); - Debug.Assert(success); // Buffer should always be visible since default MemoryStream constructor was used. + public override async ValueTask DisposeAsync() + { + if (!_disposed) + { + _disposed = true; + } + await _internalStream.FlushAsync().ConfigureAwait(false); + Complete(); + await base.DisposeAsync().ConfigureAwait(false); + } - return bytes; + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed, this); } } } diff --git a/src/libraries/System.Net.Requests/src/System/Net/RequestStreamContent.cs b/src/libraries/System.Net.Requests/src/System/Net/RequestStreamContent.cs new file mode 100644 index 0000000000000..b78829c22de72 --- /dev/null +++ b/src/libraries/System.Net.Requests/src/System/Net/RequestStreamContent.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.IO; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net +{ + internal sealed class RequestStreamContent(TaskCompletionSource getStreamTcs, TaskCompletionSource completeTcs) : HttpContent + { + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) + { + return SerializeToStreamAsync(stream, context, default); + } + protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken) + { + Debug.Assert(stream is not null); + + getStreamTcs.TrySetResult(stream); + await completeTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + protected override bool TryComputeLength(out long length) + { + length = -1; + return false; + } + } +} diff --git a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs index 73c66872a7e56..e88d8800f99f0 100644 --- a/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs +++ b/src/libraries/System.Net.Requests/tests/HttpWebRequestTest.cs @@ -1914,7 +1914,7 @@ public void Abort_CreateRequestThenAbort_Success(Uri remoteServer) } [Theory] - [InlineData(HttpRequestCacheLevel.NoCacheNoStore, null, null, new string[] { "Pragma: no-cache", "Cache-Control: no-store, no-cache"})] + [InlineData(HttpRequestCacheLevel.NoCacheNoStore, null, null, new string[] { "Pragma: no-cache", "Cache-Control: no-store, no-cache" })] [InlineData(HttpRequestCacheLevel.Reload, null, null, new string[] { "Pragma: no-cache", "Cache-Control: no-cache" })] [InlineData(HttpRequestCacheLevel.CacheOrNextCacheOnly, null, null, new string[] { "Cache-Control: only-if-cached" })] [InlineData(HttpRequestCacheLevel.Default, HttpCacheAgeControl.MinFresh, 10, new string[] { "Cache-Control: min-fresh=10" })] @@ -2077,6 +2077,125 @@ await server.AcceptConnectionAsync(async connection => }); } + [Fact] + public async Task SendHttpPostRequest_BufferingDisabled_ConnectionShouldStartWithRequestStream() + { + await LoopbackServer.CreateClientAndServerAsync( + async (uri) => + { + HttpWebRequest request = WebRequest.CreateHttp(uri); + request.Method = "POST"; + request.AllowWriteStreamBuffering = false; + request.SendChunked = true; + var stream = await request.GetRequestStreamAsync(); + await Assert.ThrowsAnyAsync(() => request.GetResponseAsync()); + }, + async (server) => + { + await server.AcceptConnectionAsync(_ => + { + return Task.CompletedTask; + }); + } + ); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SendHttpPostRequest_WhenBufferingChanges_Success(bool buffering) + { + byte[] randomData = Encoding.ASCII.GetBytes("Hello World!!!!\n"); + await LoopbackServer.CreateClientAndServerAsync( + async (uri) => + { + int size = randomData.Length * 100; + HttpWebRequest request = WebRequest.CreateHttp(uri); + request.Method = "POST"; + request.AllowWriteStreamBuffering = buffering; + using var stream = await request.GetRequestStreamAsync(); + for (int i = 0; i < size / randomData.Length; i++) + { + await stream.WriteAsync(new ReadOnlyMemory(randomData)); + } + await request.GetResponseAsync(); + }, + async (server) => + { + await server.AcceptConnectionAsync(async connection => + { + var data = await connection.ReadRequestDataAsync(); + for (int i = 0; i < data.Body.Length; i += randomData.Length) + { + Assert.Equal(randomData, data.Body[i..(i + randomData.Length)]); + } + await connection.SendResponseAsync(); + }); + } + ); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendHttpRequest_WhenNotBuffering_SendSuccess(bool isChunked) + { + byte[] firstBlock = "Hello"u8.ToArray(); + byte[] secondBlock = "WorlddD"u8.ToArray(); + SemaphoreSlim sem = new(0); + await LoopbackServer.CreateClientAndServerAsync( + async (uri) => + { + HttpWebRequest request = WebRequest.CreateHttp(uri); + request.Method = "POST"; + if (isChunked is false) + { + request.ContentLength = 5 + 7; + } + request.AllowWriteStreamBuffering = false; + + using (Stream requestStream = await request.GetRequestStreamAsync()) + { + requestStream.Write(firstBlock); + requestStream.Flush(); + await sem.WaitAsync(); + requestStream.Write(secondBlock); + requestStream.Flush(); + } + await request.GetResponseAsync(); + sem.Release(); + }, + async (server) => + { + await server.AcceptConnectionAsync(async (connection) => + { + byte[] buffer = new byte[1024]; + await connection.ReadRequestHeaderAsync(); + if (isChunked) + { + // Discard chunk length and CRLF. + await connection.ReadLineAsync(); + } + int readBytes = await connection.ReadBlockAsync(buffer, 0, firstBlock.Length); + Assert.Equal(firstBlock.Length, readBytes); + Assert.Equal(firstBlock, buffer[..readBytes]); + sem.Release(); + if (isChunked) + { + // Discard CRLF, chunk length and CRLF. + await connection.ReadLineAsync(); + await connection.ReadLineAsync(); + } + readBytes = await connection.ReadBlockAsync(buffer, 0, secondBlock.Length); + Assert.Equal(secondBlock.Length, readBytes); + Assert.Equal(secondBlock, buffer[..readBytes]); + await connection.SendResponseAsync(); + await sem.WaitAsync(); + }); + } + ); + } + [Fact] public async Task SendHttpPostRequest_WithContinueTimeoutAndBody_BodyIsDelayed() { @@ -2087,18 +2206,20 @@ await LoopbackServer.CreateClientAndServerAsync( request.Method = "POST"; request.ServicePoint.Expect100Continue = true; request.ContinueTimeout = 30000; - Stream requestStream = await request.GetRequestStreamAsync(); - requestStream.Write("aaaa\r\n\r\n"u8); + using (Stream requestStream = await request.GetRequestStreamAsync()) + { + requestStream.Write("aaaa\r\n\r\n"u8); + } await GetResponseAsync(request); }, async (server) => { - await server.AcceptConnectionAsync(async (client) => + await server.AcceptConnectionAsync(async (connection) => { - await client.ReadRequestHeaderAsync(); + await connection.ReadRequestHeaderAsync(); // This should time out, because we're expecting the body itself but we'll get it after 30 sec. - await Assert.ThrowsAsync(() => client.ReadLineAsync().WaitAsync(TimeSpan.FromMilliseconds(100))); - await client.SendResponseAsync(); + await Assert.ThrowsAsync(() => connection.ReadLineAsync().WaitAsync(TimeSpan.FromMilliseconds(100))); + await connection.SendResponseAsync(); }); } ); @@ -2116,19 +2237,21 @@ await LoopbackServer.CreateClientAndServerAsync( request.Method = "POST"; request.ServicePoint.Expect100Continue = expect100Continue; request.ContinueTimeout = continueTimeout; - Stream requestStream = await request.GetRequestStreamAsync(); - requestStream.Write("aaaa\r\n\r\n"u8); + using (Stream requestStream = await request.GetRequestStreamAsync()) + { + requestStream.Write("aaaa\r\n\r\n"u8); + } await GetResponseAsync(request); }, async (server) => { - await server.AcceptConnectionAsync(async (client) => + await server.AcceptConnectionAsync(async (connection) => { - await client.ReadRequestHeaderAsync(); + await connection.ReadRequestHeaderAsync(); // This should not time out, because we're expecting the body itself and we should get it after 1 sec. - string data = await client.ReadLineAsync().WaitAsync(TimeSpan.FromSeconds(10)); + string data = await connection.ReadLineAsync().WaitAsync(TimeSpan.FromSeconds(10)); Assert.StartsWith("aaaa", data); - await client.SendResponseAsync(); + await connection.SendResponseAsync(); }); }); } @@ -2149,9 +2272,9 @@ await LoopbackServer.CreateClientAndServerAsync( async (server) => { await server.AcceptConnectionAsync( - async (client) => + async (connection) => { - List headers = await client.ReadRequestHeaderAsync(); + List headers = await connection.ReadRequestHeaderAsync(); if (expect100Continue) { Assert.Contains("Expect: 100-continue", headers); @@ -2160,7 +2283,7 @@ await server.AcceptConnectionAsync( { Assert.DoesNotContain("Expect: 100-continue", headers); } - await client.SendResponseAsync(); + await connection.SendResponseAsync(); } ); }