Skip to content

Commit

Permalink
Implement HttpWebRequest AllowWriteStreamBuffering property (dotnet#9…
Browse files Browse the repository at this point in the history
…5001)

* Implement HttpWebRequest WriteStreamBuffering

* Fix tests

* Fix tests

* Review feedback

* Fix timeouts

* fix test build

* Hang test

* Fix build

* fix tests

* delete unnecessary line

* fix tests

* refactor test

* Dispose http client

* Fix buffer write in RequestStream.Flush() method

* Fix same length bug in FlushAsync

* Update ContentLength in HttpWebRequestTest.cs

* Fix flushing and ending request stream

* Fix flushing and ending request stream

* Fix FlushAsync method to handle cancellation

* Update src/libraries/System.Net.Requests/src/System/Net/HttpClientContentStream.cs

Co-authored-by: Anton Firszov <antonfir@gmail.com>

* Review feedback

* Bound streamBuffer lifecycle to HttpClientContentStream

* Review feedback

* Review feedback

* Change ??= to =

* change delay on test

* Apply suggestions from code review

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>

* Fix build

* Review feedback

* Apply suggestions from code review

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>

* Review feedback

* Apply suggestions from code review

Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>

* Add ProtocolViolationException if we're not buffering and we didn't set either SendChunked or ContentLength

* Review feedback

* Add test for not buffering and sending the content before we call `GetResponse`

* Update src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs

Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>

* Remove exception

* Review feedback

* Fix string resources

* Add RequestStreamContent class for handling request stream serialization

* Seperate Buffering and Non-Buffering Streams and Connect non-buffering stream to HttpClient Internal Stream directly

* Remove unnecessary code in HttpWebRequestTest

* Use random data for testing

* Update src/libraries/System.Net.Requests/src/System/Net/RequestBufferingStream.cs

Co-authored-by: Anton Firszov <antonfir@gmail.com>

* Remove default flushes and add flush on test

* Review feedback

* Removing unused code

---------

Co-authored-by: Anton Firszov <antonfir@gmail.com>
Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>
Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com>
  • Loading branch information
4 people authored Mar 5, 2024
1 parent 24820b6 commit 40bd270
Show file tree
Hide file tree
Showing 6 changed files with 466 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
<Compile Include="System\Net\IWebRequestCreate.cs" />
<Compile Include="System\Net\ProtocolViolationException.cs" />
<Compile Include="System\Net\RequestStream.cs" />
<Compile Include="System\Net\RequestBufferingStream.cs" />
<Compile Include="System\Net\TaskExtensions.cs" />
<Compile Include="System\Net\WebException.cs" />
<Compile Include="System\Net\WebExceptionStatus.cs" />
Expand All @@ -48,6 +49,7 @@
<Compile Include="System\Net\NetRes.cs" />
<Compile Include="System\Net\NetworkStreamWrapper.cs" />
<Compile Include="System\Net\TimerThread.cs" />
<Compile Include="System\Net\RequestStreamContent.cs" />
<Compile Include="System\Net\Cache\HttpCacheAgeControl.cs" />
<Compile Include="System\Net\Cache\HttpRequestCacheLevel.cs" />
<Compile Include="System\Net\Cache\HttpRequestCachePolicy.cs" />
Expand Down
177 changes: 112 additions & 65 deletions src/libraries/System.Net.Requests/src/System/Net/HttpWebRequest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -12,6 +13,7 @@
using System.Net.Http.Headers;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
Expand Down Expand Up @@ -40,6 +42,7 @@ public class HttpWebRequest : WebRequest, ISerializable
private IWebProxy? _proxy = WebRequest.DefaultWebProxy;

private Task<HttpResponseMessage>? _sendRequestTask;
private HttpRequestMessage? _sendRequestMessage;

private static int _defaultMaxResponseHeadersLength = HttpHandlerDefaults.DefaultMaxResponseHeadersLength;
private static int _defaultMaximumErrorResponseLength = -1;
Expand All @@ -62,7 +65,7 @@ public class HttpWebRequest : WebRequest, ISerializable
private bool _hostHasPort;
private Uri? _hostUri;

private RequestStream? _requestStream;
private Stream? _requestStream;
private TaskCompletionSource<Stream>? _requestStreamOperation;
private TaskCompletionSource<WebResponse>? _responseOperation;
private AsyncCallback? _requestStreamCallback;
Expand All @@ -78,6 +81,8 @@ public class HttpWebRequest : WebRequest, ISerializable
private static readonly object s_syncRoot = new object();
private static volatile HttpClient? s_cachedHttpClient;
private static HttpClientParameters? s_cachedHttpClientParameters;
private bool _disposeRequired;
private HttpClient? _httpClient;

//these should be safe.
[Flags]
Expand Down Expand Up @@ -1003,17 +1008,17 @@ public override void Abort()
{
_responseCallback(_responseOperation.Task);
}

// Cancel the underlying send operation.
Debug.Assert(_sendRequestCts != null);
_sendRequestCts.Cancel();
}
else if (_requestStreamOperation != null)
if (_requestStreamOperation != null)
{
if (_requestStreamOperation.TrySetCanceled() && _requestStreamCallback != null)
{
_requestStreamCallback(_requestStreamOperation.Task);
}

// Cancel the underlying send operation.
Debug.Assert(_sendRequestCts != null);
_sendRequestCts.Cancel();
}
}

Expand Down Expand Up @@ -1041,8 +1046,7 @@ public override WebResponse GetResponse()
{
try
{
_sendRequestCts = new CancellationTokenSource();
return SendRequest(async: false).GetAwaiter().GetResult();
return HandleResponse(async: false).GetAwaiter().GetResult();
}
catch (Exception ex)
{
Expand All @@ -1052,10 +1056,11 @@ public override WebResponse GetResponse()

public override Stream GetRequestStream()
{
CheckRequestStream();
return InternalGetRequestStream().Result;
}

private Task<Stream> InternalGetRequestStream()
private void CheckRequestStream()
{
CheckAbort();

Expand All @@ -1073,10 +1078,28 @@ private Task<Stream> InternalGetRequestStream()
{
throw new InvalidOperationException(SR.net_reqsubmitted);
}
}

_requestStream = new RequestStream();
private async Task<Stream> InternalGetRequestStream()
{
// If we aren't buffering we need to open the connection right away.
// Because we need to send the data as soon as possible when it's available from the RequestStream.
// Making this allows us to keep the sync send request path for buffering cases.
if (AllowWriteStreamBuffering is false)
{
// We're calling SendRequest with async, because we need to open the connection and send the request
// Otherwise, sync path will block the current thread until the request is sent.
TaskCompletionSource<Stream> getStreamTcs = new();
TaskCompletionSource completeTcs = new();
_sendRequestTask = SendRequest(async: true, new RequestStreamContent(getStreamTcs, completeTcs));
_requestStream = new RequestStream(await getStreamTcs.Task.ConfigureAwait(false), completeTcs);
}
else
{
_requestStream = new RequestBufferingStream();
}

return Task.FromResult((Stream)_requestStream);
return _requestStream;
}

public Stream EndGetRequestStream(IAsyncResult asyncResult, out TransportContext? context)
Expand All @@ -1100,6 +1123,8 @@ public override IAsyncResult BeginGetRequestStream(AsyncCallback? callback, obje
throw new InvalidOperationException(SR.net_repcall);
}

CheckRequestStream();

_requestStreamCallback = callback;
_requestStreamOperation = InternalGetRequestStream().ToApm(callback, state);

Expand Down Expand Up @@ -1133,78 +1158,95 @@ public override Stream EndGetRequestStream(IAsyncResult asyncResult)
return stream;
}

private async Task<WebResponse> SendRequest(bool async)
private Task<HttpResponseMessage> SendRequest(bool async, HttpContent? content = null)
{
if (RequestSubmitted)
{
throw new InvalidOperationException(SR.net_reqsubmitted);
}

var request = new HttpRequestMessage(HttpMethod.Parse(_originVerb), _requestUri);
_sendRequestMessage = new HttpRequestMessage(HttpMethod.Parse(_originVerb), _requestUri);
_sendRequestCts = new CancellationTokenSource();
_httpClient = GetCachedOrCreateHttpClient(async, out _disposeRequired);

bool disposeRequired = false;
HttpClient? client = null;
try
if (content is not null)
{
_sendRequestMessage.Content = content;
}

if (_hostUri is not null)
{
client = GetCachedOrCreateHttpClient(async, out disposeRequired);
if (_requestStream != null)
_sendRequestMessage.Headers.Host = Host;
}

AddCacheControlHeaders(_sendRequestMessage);

// Copy the HttpWebRequest request headers from the WebHeaderCollection into HttpRequestMessage.Headers and
// HttpRequestMessage.Content.Headers.
foreach (string headerName in _webHeaderCollection)
{
// The System.Net.Http APIs require HttpRequestMessage headers to be properly divided between the request headers
// collection and the request content headers collection for all well-known header names. And custom headers
// are only allowed in the request headers collection and not in the request content headers collection.
if (IsWellKnownContentHeader(headerName))
{
ArraySegment<byte> bytes = _requestStream.GetBuffer();
request.Content = new ByteArrayContent(bytes.Array!, bytes.Offset, bytes.Count);
_sendRequestMessage.Content ??= new ByteArrayContent(Array.Empty<byte>());
_sendRequestMessage.Content.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]);
}

if (_hostUri != null)
else
{
request.Headers.Host = Host;
_sendRequestMessage.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]);
}
}

AddCacheControlHeaders(request);
if (_servicePoint?.Expect100Continue == true)
{
_sendRequestMessage.Headers.ExpectContinue = true;
}

// Copy the HttpWebRequest request headers from the WebHeaderCollection into HttpRequestMessage.Headers and
// HttpRequestMessage.Content.Headers.
foreach (string headerName in _webHeaderCollection)
{
// The System.Net.Http APIs require HttpRequestMessage headers to be properly divided between the request headers
// collection and the request content headers collection for all well-known header names. And custom headers
// are only allowed in the request headers collection and not in the request content headers collection.
if (IsWellKnownContentHeader(headerName))
{
// Create empty content so that we can send the entity-body header.
request.Content ??= new ByteArrayContent(Array.Empty<byte>());
_sendRequestMessage.Headers.TransferEncodingChunked = SendChunked;

request.Content.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]);
}
else
{
request.Headers.TryAddWithoutValidation(headerName, _webHeaderCollection[headerName!]);
}
}
if (KeepAlive)
{
_sendRequestMessage.Headers.Connection.Add(HttpKnownHeaderNames.KeepAlive);
}
else
{
_sendRequestMessage.Headers.ConnectionClose = true;
}

request.Headers.TransferEncodingChunked = SendChunked;
_sendRequestMessage.Version = ProtocolVersion;
HttpCompletionOption completionOption = _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead;
// If we're not buffering, there is no way to open the connection and not send the request without async.
// So we should use Async, if we're not buffering.
_sendRequestTask = async || !AllowWriteStreamBuffering ?
_httpClient.SendAsync(_sendRequestMessage, completionOption, _sendRequestCts.Token) :
Task.FromResult(_httpClient.Send(_sendRequestMessage, completionOption, _sendRequestCts.Token));

if (KeepAlive)
{
request.Headers.Connection.Add(HttpKnownHeaderNames.KeepAlive);
}
else
{
request.Headers.ConnectionClose = true;
}
return _sendRequestTask!;
}

if (_servicePoint?.Expect100Continue == true)
{
request.Headers.ExpectContinue = true;
}
private async Task<WebResponse> HandleResponse(bool async)
{
// If user code used requestStream and didn't dispose it
// We're completing it here.
if (_requestStream is RequestStream requestStream)
{
requestStream.Complete();
}

request.Version = ProtocolVersion;
if (_sendRequestTask is null && _requestStream is RequestBufferingStream requestBufferingStream)
{
ArraySegment<byte> buffer = requestBufferingStream.GetBuffer();
_sendRequestTask = SendRequest(async, new ByteArrayContent(buffer.Array!, buffer.Offset, buffer.Count));
}

_sendRequestTask = async ?
client.SendAsync(request, _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead, _sendRequestCts!.Token) :
Task.FromResult(client.Send(request, _allowReadStreamBuffering ? HttpCompletionOption.ResponseContentRead : HttpCompletionOption.ResponseHeadersRead, _sendRequestCts!.Token));
_sendRequestTask ??= SendRequest(async);

try
{
HttpResponseMessage responseMessage = await _sendRequestTask.ConfigureAwait(false);

HttpWebResponse response = new HttpWebResponse(responseMessage, _requestUri, _cookieContainer);
HttpWebResponse response = new(responseMessage, _requestUri, _cookieContainer);

int maxSuccessStatusCode = AllowAutoRedirect ? 299 : 399;
if ((int)response.StatusCode > maxSuccessStatusCode || (int)response.StatusCode < 200)
Expand All @@ -1220,9 +1262,15 @@ private async Task<WebResponse> SendRequest(bool async)
}
finally
{
if (disposeRequired)
_sendRequestMessage?.Dispose();
if (_requestStream is RequestBufferingStream bufferStream)
{
client?.Dispose();
bufferStream.GetMemoryStream().Dispose();
}

if (_disposeRequired)
{
_httpClient?.Dispose();
}
}
}
Expand Down Expand Up @@ -1348,9 +1396,8 @@ public override IAsyncResult BeginGetResponse(AsyncCallback? callback, object? s
throw new InvalidOperationException(SR.net_repcall);
}

_sendRequestCts = new CancellationTokenSource();
_responseCallback = callback;
_responseOperation = SendRequest(async: true).ToApm(callback, state);
_responseOperation = HandleResponse(async: true).ToApm(callback, state);

return _responseOperation.Task;
}
Expand Down
Loading

0 comments on commit 40bd270

Please sign in to comment.