diff --git a/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java b/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java index ced718e61a..ef9fa01f0d 100644 --- a/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java +++ b/libraries/datasource/src/androidTest/java/androidx/media3/datasource/DefaultHttpDataSourceTest.java @@ -129,6 +129,123 @@ public class DefaultHttpDataSourceTest { assertThat(exception.responseBody).isEqualTo(TestUtil.createByteArray(1, 2, 3)); } + @Test + public void open_redirectCrossProtocol_shouldNotAllowCrossProtocol() throws Exception { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = mockWebServer.url("/redirect-path").toString(); + String httpsUrl = newLocationUrl.replaceFirst("http", "https"); + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", httpsUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + HttpDataSource.InvalidResponseCodeException exception = + assertThrows( + HttpDataSource.InvalidResponseCodeException.class, + () -> defaultHttpDataSource.open(dataSpec)); + + assertThat(exception.responseCode).isEqualTo(302); + } + + @Test + public void open_redirectCrossProtocol_shouldForceOriginalProtocol() + throws HttpDataSourceException, InterruptedException { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .setCrossProtocolRedirectsForceOriginal(true) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = mockWebServer.url("/redirect-path").toString(); + String httpsUrl = newLocationUrl.replaceFirst("http", "https"); + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", httpsUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + defaultHttpDataSource.open(dataSpec); + + RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request1).isNotNull(); + assertThat(request1.getPath()).isEqualTo("/test-path"); + assertThat(request1.getMethod()).isEqualTo("POST"); + assertThat(request1.getBodySize()).isEqualTo(postBody.length); + RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request2).isNotNull(); + assertThat(request2.getPath()).isEqualTo("/redirect-path"); + assertThat(request2.getMethod()).isEqualTo("GET"); + assertThat(request2.getBodySize()).isEqualTo(0); + } + + @Test + public void open_redirectSameProtocolWithRelativeReference_shouldFollowRedirect() + throws HttpDataSourceException, InterruptedException { + byte[] postBody = new byte[] {1, 2, 3}; + DefaultHttpDataSource defaultHttpDataSource = + new DefaultHttpDataSource.Factory() + .setConnectTimeoutMs(1000) + .setReadTimeoutMs(1000) + .setAllowCrossProtocolRedirects(false) + .setCrossProtocolRedirectsForceOriginal(true) + .createDataSource(); + + MockWebServer mockWebServer = new MockWebServer(); + String newLocationUrl = "https/redirect-path"; + mockWebServer.enqueue( + new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location", newLocationUrl)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_OK)); + + DataSpec dataSpec = + new DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody(postBody) + .build(); + + defaultHttpDataSource.open(dataSpec); + + RecordedRequest request1 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request1).isNotNull(); + assertThat(request1.getPath()).isEqualTo("/test-path"); + assertThat(request1.getMethod()).isEqualTo("POST"); + assertThat(request1.getBodySize()).isEqualTo(postBody.length); + RecordedRequest request2 = mockWebServer.takeRequest(10, SECONDS); + assertThat(request2).isNotNull(); + assertThat(request2.getPath()).isEqualTo("/https/redirect-path"); + assertThat(request2.getMethod()).isEqualTo("GET"); + assertThat(request2.getBodySize()).isEqualTo(0); + } + @Test public void open_redirectChanges302PostToGet() throws HttpDataSourceException, InterruptedException { diff --git a/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java b/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java index c43e6ddd25..cfefab49da 100644 --- a/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java +++ b/libraries/datasource/src/main/java/androidx/media3/datasource/DefaultHttpDataSource.java @@ -75,6 +75,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou private int connectTimeoutMs; private int readTimeoutMs; private boolean allowCrossProtocolRedirects; + private boolean crossProtocolRedirectsForceOriginal; private boolean keepPostFor302Redirects; /** Creates an instance. */ @@ -154,6 +155,23 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou return this; } + /** + * Sets whether cross protocol redirects should be forced to follow original protocol. This + * should only be set if {@code allowCrossProtocolRedirects} is false. + * + *

The default is {@code false}. + * + * @param crossProtocolRedirectsForceOriginal Whether to force original protocol. + * @return This factory. + */ + @CanIgnoreReturnValue + @UnstableApi + public Factory setCrossProtocolRedirectsForceOriginal( + boolean crossProtocolRedirectsForceOriginal) { + this.crossProtocolRedirectsForceOriginal = crossProtocolRedirectsForceOriginal; + return this; + } + /** * Sets a content type {@link Predicate}. If a content type is rejected by the predicate then a * {@link HttpDataSource.InvalidContentTypeException} is thrown from {@link @@ -209,6 +227,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou connectTimeoutMs, readTimeoutMs, allowCrossProtocolRedirects, + crossProtocolRedirectsForceOriginal, defaultRequestProperties, contentTypePredicate, keepPostFor302Redirects); @@ -232,6 +251,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou private static final long MAX_BYTES_TO_DRAIN = 2048; private final boolean allowCrossProtocolRedirects; + private final boolean crossProtocolRedirectsForceOriginal; private final int connectTimeoutMillis; private final int readTimeoutMillis; @Nullable private final String userAgent; @@ -300,6 +320,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou connectTimeoutMillis, readTimeoutMillis, allowCrossProtocolRedirects, + /* crossProtocolRedirectsForceOriginal= */ false, defaultRequestProperties, /* contentTypePredicate= */ null, /* keepPostFor302Redirects= */ false); @@ -310,6 +331,7 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou int connectTimeoutMillis, int readTimeoutMillis, boolean allowCrossProtocolRedirects, + boolean crossProtocolRedirectsForceOriginal, @Nullable RequestProperties defaultRequestProperties, @Nullable Predicate contentTypePredicate, boolean keepPostFor302Redirects) { @@ -318,6 +340,12 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou this.connectTimeoutMillis = connectTimeoutMillis; this.readTimeoutMillis = readTimeoutMillis; this.allowCrossProtocolRedirects = allowCrossProtocolRedirects; + this.crossProtocolRedirectsForceOriginal = crossProtocolRedirectsForceOriginal; + if (allowCrossProtocolRedirects && crossProtocolRedirectsForceOriginal) { + throw new IllegalArgumentException( + "crossProtocolRedirectsForceOriginal should not be set if allowCrossProtocolRedirects is" + + " true"); + } this.defaultRequestProperties = defaultRequestProperties; this.contentTypePredicate = contentTypePredicate; this.requestProperties = new RequestProperties(); @@ -554,7 +582,9 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou long length = dataSpec.length; boolean allowGzip = dataSpec.isFlagSet(DataSpec.FLAG_ALLOW_GZIP); - if (!allowCrossProtocolRedirects && !keepPostFor302Redirects) { + if (!allowCrossProtocolRedirects + && !crossProtocolRedirectsForceOriginal + && !keepPostFor302Redirects) { // HttpURLConnection disallows cross-protocol redirects, but otherwise performs redirection // automatically. This is the behavior we want, so use it. return makeConnection( @@ -727,15 +757,27 @@ public class DefaultHttpDataSource extends BaseDataSource implements HttpDataSou HttpDataSourceException.TYPE_OPEN); } if (!allowCrossProtocolRedirects && !protocol.equals(originalUrl.getProtocol())) { - throw new HttpDataSourceException( - "Disallowed cross-protocol redirect (" - + originalUrl.getProtocol() - + " to " - + protocol - + ")", - dataSpec, - PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED, - HttpDataSourceException.TYPE_OPEN); + if (!crossProtocolRedirectsForceOriginal) { + throw new HttpDataSourceException( + "Disallowed cross-protocol redirect (" + + originalUrl.getProtocol() + + " to " + + protocol + + ")", + dataSpec, + PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED, + HttpDataSourceException.TYPE_OPEN); + } else { + try { + url = new URL(url.toString().replaceFirst(protocol, originalUrl.getProtocol())); + } catch (MalformedURLException e) { + throw new HttpDataSourceException( + e, + dataSpec, + PlaybackException.ERROR_CODE_IO_NETWORK_CONNECTION_FAILED, + HttpDataSourceException.TYPE_OPEN); + } + } } return url; }