diff --git a/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java b/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java index ced718e61a..ef9fa01f0d 100644 --- a/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java +++ b/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java @@ -129,6 +129,123 @@ public class DefaultHttpDataSourceTest { assertThat(exception.responseBody).isEqualTo(TestUtil.createByteArray(1, 2, 3)); } + @Test + public void open_redirectCrossProtocol_shouldNotAllowCrossProtocol() throws Exception { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = mockWebServer.url("/redirect-path").toString(); + String httpsUrl = newLocationUrl.replaceFirst("http", "https"); + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", httpsUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + HttpDataSource.InvalidResponseCodeException exception = + assertThrows( + HttpDataSource.InvalidResponseCodeException.class, + () -> defaultHttpDataSource.open(dataSpec)); + + assertThat(exception.responseCode).isEqualTo(302); + } + + @Test + public void open_redirectCrossProtocol_shouldForceOriginalProtocol() + throws HttpDataSourceException, InterruptedException { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .setCrossProtocolRedirectsForceOriginal(true) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = mockWebServer.url("/redirect-path").toString(); + String httpsUrl = newLocationUrl.replaceFirst("http", "https"); + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", httpsUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + defaultHttpDataSource.open(dataSpec); + + RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request1).isNotNull(); + assertThat(request1.getPath()).isEqualTo("/test-path"); + assertThat(request1.getMethod()).isEqualTo("POST"); + assertThat(request1.getBodySize()).isEqualTo(postBody.length); + RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request2).isNotNull(); + assertThat(request2.getPath()).isEqualTo("/redirect-path"); + assertThat(request2.getMethod()).isEqualTo("GET"); + assertThat(request2.getBodySize()).isEqualTo(0); + } + + @Test + public void open_redirectSameProtocolWithRelativeReference_shouldFollowRedirect() + throws HttpDataSourceException, InterruptedException { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .setCrossProtocolRedirectsForceOriginal(true) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = "https/redirect-path"; + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", newLocationUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + defaultHttpDataSource.open(dataSpec); + + RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request1).isNotNull(); + assertThat(request1.getPath()).isEqualTo("/test-path"); + assertThat(request1.getMethod()).isEqualTo("POST"); + assertThat(request1.getBodySize()).isEqualTo(postBody.length); + RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request2).isNotNull(); + assertThat(request2.getPath()).isEqualTo("/https/redirect-path"); + assertThat(request2.getMethod()).isEqualTo("GET"); + assertThat(request2.getBodySize()).isEqualTo(0); + } + @Test public void open_redirectChanges302PostToGet() throws HttpDataSourceException, InterruptedException { diff --git a/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java b/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java index c43e6ddd25..cfefab49da 100644 --- a/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java +++ b/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java @@ -75,6 +75,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou private int connectTimeoutMs; private int readTimeoutMs; private boolean allowCrossProtocolRedirects; + private boolean crossProtocolRedirectsForceOriginal; private boolean keepPostFor302Redirects; /** Creates an instance. */ @@ -154,6 +155,23 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou return this; } + /** + * Sets whether cross protocol redirects should be forced to follow original protocol. This + * should only be set if {@code allowCrossProtocolRedirects} is false. + * + *
The default is {@code false}.
+ *
+ * @param crossProtocolRedirectsForceOriginal Whether to force original protocol.
+ * @return This factory.
+ */
+ @CanIgnoreReturnValue
+ @UnstableApi
+ public Factory setCrossProtocolRedirectsForceOriginal(
+ boolean crossProtocolRedirectsForceOriginal) {
+ this.crossProtocolRedirectsForceOriginal = crossProtocolRedirectsForceOriginal;
+ return this;
+ }
+
/**
* Sets a content type {@link Predicate}. If a content type is rejected by the predicate then a
* {@link HttpDataSource.InvalidContentTypeException} is thrown from {@link
@@ -209,6 +227,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
connectTimeoutMs,
readTimeoutMs,
allowCrossProtocolRedirects,
+ crossProtocolRedirectsForceOriginal,
defaultRequestProperties,
contentTypePredicate,
keepPostFor302Redirects);
@@ -232,6 +251,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
private static final long MAX_BYTES_TO_DRAIN = 2048;
private final boolean allowCrossProtocolRedirects;
+ private final boolean crossProtocolRedirectsForceOriginal;
private final int connectTimeoutMillis;
private final int readTimeoutMillis;
@Nullable private final String userAgent;
@@ -300,6 +320,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
connectTimeoutMillis,
readTimeoutMillis,
allowCrossProtocolRedirects,
+ /* crossProtocolRedirectsForceOriginal= */ false,
defaultRequestProperties,
/* contentTypePredicate= */ null,
/* keepPostFor302Redirects= */ false);
@@ -310,6 +331,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou
int connectTimeoutMillis,
int readTimeoutMillis,
boolean allowCrossProtocolRedirects,
+ boolean crossProtocolRedirectsForceOriginal,
@Nullable RequestProperties defaultRequestProperties,
@Nullable Predicate