DefaultHttpDS: Allow forcing cross protocol redirects to original

One can set crossProtocolRedirectsForceOriginal to force cross protocol redirects to use the original protocol. This might cause the connection to fail so it can only used when allowCrossProtocolRedirects is set to false or unset (default false).

PiperOrigin-RevId: 631937956
This commit is contained in:
Googler 2024-05-08 15:08:34 -07:00 committed by Copybara-Service
parent d977eab5f1
commit 524181d7f2
2 changed files with 169 additions and 10 deletions

View File

@ -129,6 +129,123 @@ public class DefaultHttpDataSourceTest {
assertThat(exception.responseBody).isEqualTo(TestUtil.createByteArray(1, 2, 3));
}
@Test
public void open_redirectCrossProtocol_shouldNotAllowCrossProtocol() throws Exception {
byte[] postBody = new byte[] {1, 2, 3};
DefaultHttpDataSource defaultHttpDataSource =
new DefaultHttpDataSource.Factory()
.setConnectTimeoutMs(1000)
.setReadTimeoutMs(1000)
.setAllowCrossProtocolRedirects(false)
.createDataSource();
MockWebServer mockWebServer = new MockWebServer();
String newLocationUrl = mockWebServer.url("/redirect-path").toString();
String httpsUrl = newLocationUrl.replaceFirst("http", "https");
mockWebServer.enqueue(
new MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location", httpsUrl));
mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK));
DataSpec dataSpec =
new DataSpec.Builder()
.setUri(mockWebServer.url("/test-path").toString())
.setHttpMethod(DataSpec.HTTP_METHOD_POST)
.setHttpBody(postBody)
.build();
HttpDataSource.InvalidResponseCodeException exception =
assertThrows(
HttpDataSource.InvalidResponseCodeException.class,
() -> defaultHttpDataSource.open(dataSpec));
assertThat(exception.responseCode).isEqualTo(302);
}
@Test
public void open_redirectCrossProtocol_shouldForceOriginalProtocol()
throws HttpDataSourceException, InterruptedException {
byte[] postBody = new byte[] {1, 2, 3};
DefaultHttpDataSource defaultHttpDataSource =
new DefaultHttpDataSource.Factory()
.setConnectTimeoutMs(1000)
.setReadTimeoutMs(1000)
.setAllowCrossProtocolRedirects(false)
.setCrossProtocolRedirectsForceOriginal(true)
.createDataSource();
MockWebServer mockWebServer = new MockWebServer();
String newLocationUrl = mockWebServer.url("/redirect-path").toString();
String httpsUrl = newLocationUrl.replaceFirst("http", "https");
mockWebServer.enqueue(
new MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location", httpsUrl));
mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK));
DataSpec dataSpec =
new DataSpec.Builder()
.setUri(mockWebServer.url("/test-path").toString())
.setHttpMethod(DataSpec.HTTP_METHOD_POST)
.setHttpBody(postBody)
.build();
defaultHttpDataSource.open(dataSpec);
RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS);
assertThat(request1).isNotNull();
assertThat(request1.getPath()).isEqualTo("/test-path");
assertThat(request1.getMethod()).isEqualTo("POST");
assertThat(request1.getBodySize()).isEqualTo(postBody.length);
RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS);
assertThat(request2).isNotNull();
assertThat(request2.getPath()).isEqualTo("/redirect-path");
assertThat(request2.getMethod()).isEqualTo("GET");
assertThat(request2.getBodySize()).isEqualTo(0);
}
@Test
public void open_redirectSameProtocolWithRelativeReference_shouldFollowRedirect()
throws HttpDataSourceException, InterruptedException {
byte[] postBody = new byte[] {1, 2, 3};
DefaultHttpDataSource defaultHttpDataSource =
new DefaultHttpDataSource.Factory()
.setConnectTimeoutMs(1000)
.setReadTimeoutMs(1000)
.setAllowCrossProtocolRedirects(false)
.setCrossProtocolRedirectsForceOriginal(true)
.createDataSource();
MockWebServer mockWebServer = new MockWebServer();
String newLocationUrl = "https/redirect-path";
mockWebServer.enqueue(
new MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location", newLocationUrl));
mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK));
DataSpec dataSpec =
new DataSpec.Builder()
.setUri(mockWebServer.url("/test-path").toString())
.setHttpMethod(DataSpec.HTTP_METHOD_POST)
.setHttpBody(postBody)
.build();
defaultHttpDataSource.open(dataSpec);
RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS);
assertThat(request1).isNotNull();
assertThat(request1.getPath()).isEqualTo("/test-path");
assertThat(request1.getMethod()).isEqualTo("POST");
assertThat(request1.getBodySize()).isEqualTo(postBody.length);
RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS);
assertThat(request2).isNotNull();
assertThat(request2.getPath()).isEqualTo("/https/redirect-path");
assertThat(request2.getMethod()).isEqualTo("GET");
assertThat(request2.getBodySize()).isEqualTo(0);
}
@Test
public void open_redirectChanges302PostToGet()
throws HttpDataSourceException, InterruptedException {

View File

@ -75,6 +75,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
private int connectTimeoutMs;
private int readTimeoutMs;
private boolean allowCrossProtocolRedirects;
private boolean crossProtocolRedirectsForceOriginal;
private boolean keepPostFor302Redirects;
/** Creates an instance. */
@ -154,6 +155,23 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
return this;
}
/**
* Sets whether cross protocol redirects should be forced to follow original protocol. This
* should only be set if {@code allowCrossProtocolRedirects} is false.
*
* <p>The default is {@code false}.
*
* @param crossProtocolRedirectsForceOriginal Whether to force original protocol.
* @return This factory.
*/
@CanIgnoreReturnValue
@UnstableApi
public Factory setCrossProtocolRedirectsForceOriginal(
boolean crossProtocolRedirectsForceOriginal) {
this.crossProtocolRedirectsForceOriginal = crossProtocolRedirectsForceOriginal;
return this;
}
/**
* Sets a content type {@link Predicate}. If a content type is rejected by the predicate then a
* {@link HttpDataSource.InvalidContentTypeException} is thrown from {@link
@ -209,6 +227,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
connectTimeoutMs,
readTimeoutMs,
allowCrossProtocolRedirects,
crossProtocolRedirectsForceOriginal,
defaultRequestProperties,
contentTypePredicate,
keepPostFor302Redirects);
@ -232,6 +251,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
private static final long MAX_BYTES_TO_DRAIN = 2048;
private final boolean allowCrossProtocolRedirects;
private final boolean crossProtocolRedirectsForceOriginal;
private final int connectTimeoutMillis;
private final int readTimeoutMillis;
@Nullable private final String userAgent;
@ -300,6 +320,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
connectTimeoutMillis,
readTimeoutMillis,
allowCrossProtocolRedirects,
/* crossProtocolRedirectsForceOriginal= */ false,
defaultRequestProperties,
/* contentTypePredicate= */ null,
/* keepPostFor302Redirects= */ false);
@ -310,6 +331,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
int connectTimeoutMillis,
int readTimeoutMillis,
boolean allowCrossProtocolRedirects,
boolean crossProtocolRedirectsForceOriginal,
@Nullable RequestProperties defaultRequestProperties,
@Nullable Predicate<String> contentTypePredicate,
boolean keepPostFor302Redirects) {
@ -318,6 +340,12 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
this.connectTimeoutMillis = connectTimeoutMillis;
this.readTimeoutMillis = readTimeoutMillis;
this.allowCrossProtocolRedirects = allowCrossProtocolRedirects;
this.crossProtocolRedirectsForceOriginal = crossProtocolRedirectsForceOriginal;
if (allowCrossProtocolRedirects && crossProtocolRedirectsForceOriginal) {
throw new IllegalArgumentException(
"crossProtocolRedirectsForceOriginal should not be set if allowCrossProtocolRedirects is"
+ " true");
}
this.defaultRequestProperties = defaultRequestProperties;
this.contentTypePredicate = contentTypePredicate;
this.requestProperties = new RequestProperties();
@ -554,7 +582,9 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
long length = dataSpec.length;
boolean allowGzip = dataSpec.isFlagSet(DataSpec.FLAG_ALLOW_GZIP);
if (!allowCrossProtocolRedirects && !keepPostFor302Redirects) {
if (!allowCrossProtocolRedirects
&& !crossProtocolRedirectsForceOriginal
&& !keepPostFor302Redirects) {
// HttpURLConnection disallows cross-protocol redirects, but otherwise performs redirection
// automatically. This is the behavior we want, so use it.
return makeConnection(
@ -727,15 +757,27 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
HttpDataSourceException.TYPE_OPEN);
}
if (!allowCrossProtocolRedirects && !protocol.equals(originalUrl.getProtocol())) {
throw new HttpDataSourceException(
"Disallowed cross-protocol redirect ("
+ originalUrl.getProtocol()
+ " to "
+ protocol
+ ")",
dataSpec,
PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED,
HttpDataSourceException.TYPE_OPEN);
if (!crossProtocolRedirectsForceOriginal) {
throw new HttpDataSourceException(
"Disallowed cross-protocol redirect ("
+ originalUrl.getProtocol()
+ " to "
+ protocol
+ ")",
dataSpec,
PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED,
HttpDataSourceException.TYPE_OPEN);
} else {
try {
url = new URL(url.toString().replaceFirst(protocol, originalUrl.getProtocol()));
} catch (MalformedURLException e) {
throw new HttpDataSourceException(
e,
dataSpec,
PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED,
HttpDataSourceException.TYPE_OPEN);
}
}
}
return url;
}