Avoid providing invalid responses to MediaDrm

MediaDrm.provideXResponse methods (where X is Key or Provision)
only accept the response corresponding to the most recent
MediaDrm.getXRequest call.
Previously, our code allowed the following incorrect call
sequence:

a = getKeyRequest
b = getKeyRequest
provideKeyResponse(responseFor(a));

This could occur in the edge case where a second key request
was triggered while the first was still in flight. The
provideKeyResponse call would then fail, because the response
passed to it belonged to a request that had been superseded.

This change fixes the problem by treating responseFor(a)
as stale: each response is now paired with the request that
produced it, and is discarded unless that request is still the
current one. The same check is applied to provisioning requests.
A slightly better fix would be to defer calling getKeyRequest
the second time until after processing the response to the
first one; however, this is significantly harder to implement
and is probably not worth it for what should be an edge case.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=203481685
This commit is contained in:
olly 2018-07-06 08:36:14 -07:00 committed by Oliver Woodman
parent a50d31a70b
commit 6ad98405a3

View File

@ -97,6 +97,9 @@ import java.util.UUID;
private byte[] sessionId; private byte[] sessionId;
private byte[] offlineLicenseKeySetId; private byte[] offlineLicenseKeySetId;
private Object currentKeyRequest;
private Object currentProvisionRequest;
/** /**
* Instantiates a new DRM session. * Instantiates a new DRM session.
* *
@ -171,6 +174,8 @@ import java.util.UUID;
requestHandlerThread = null; requestHandlerThread = null;
mediaCrypto = null; mediaCrypto = null;
lastException = null; lastException = null;
currentKeyRequest = null;
currentProvisionRequest = null;
if (sessionId != null) { if (sessionId != null) {
mediaDrm.closeSession(sessionId); mediaDrm.closeSession(sessionId);
sessionId = null; sessionId = null;
@ -215,8 +220,8 @@ import java.util.UUID;
// Provisioning implementation. // Provisioning implementation.
public void provision() { public void provision() {
ProvisionRequest request = mediaDrm.getProvisionRequest(); currentProvisionRequest = mediaDrm.getProvisionRequest();
postRequestHandler.obtainMessage(MSG_PROVISION, request, true).sendToTarget(); postRequestHandler.post(MSG_PROVISION, currentProvisionRequest, /* allowRetry= */ true);
} }
public void onProvisionCompleted() { public void onProvisionCompleted() {
@ -289,11 +294,12 @@ import java.util.UUID;
return false; return false;
} }
private void onProvisionResponse(Object response) { private void onProvisionResponse(Object request, Object response) {
if (state != STATE_OPENING && !isOpen()) { if (request != currentProvisionRequest || (state != STATE_OPENING && !isOpen())) {
// This event is stale. // This event is stale.
return; return;
} }
currentProvisionRequest = null;
if (response instanceof Exception) { if (response instanceof Exception) {
provisioningManager.onProvisionError((Exception) response); provisioningManager.onProvisionError((Exception) response);
@ -383,20 +389,21 @@ import java.util.UUID;
licenseServerUrl = schemeData.licenseServerUrl; licenseServerUrl = schemeData.licenseServerUrl;
} }
try { try {
KeyRequest request = KeyRequest mediaDrmKeyRequest =
mediaDrm.getKeyRequest(scope, initData, mimeType, type, optionalKeyRequestParameters); mediaDrm.getKeyRequest(scope, initData, mimeType, type, optionalKeyRequestParameters);
Pair<KeyRequest, String> arguments = Pair.create(request, licenseServerUrl); currentKeyRequest = Pair.create(mediaDrmKeyRequest, licenseServerUrl);
postRequestHandler.obtainMessage(MSG_KEYS, arguments, allowRetry).sendToTarget(); postRequestHandler.post(MSG_KEYS, currentKeyRequest, allowRetry);
} catch (Exception e) { } catch (Exception e) {
onKeysError(e); onKeysError(e);
} }
} }
private void onKeyResponse(Object response) { private void onKeyResponse(Object request, Object response) {
if (!isOpen()) { if (request != currentKeyRequest || !isOpen()) {
// This event is stale. // This event is stale.
return; return;
} }
currentKeyRequest = null;
if (response instanceof Exception) { if (response instanceof Exception) {
onKeysError((Exception) response); onKeysError((Exception) response);
@ -461,12 +468,15 @@ import java.util.UUID;
@Override @Override
public void handleMessage(Message msg) { public void handleMessage(Message msg) {
Pair<?, ?> requestAndResponse = (Pair<?, ?>) msg.obj;
Object request = requestAndResponse.first;
Object response = requestAndResponse.second;
switch (msg.what) { switch (msg.what) {
case MSG_PROVISION: case MSG_PROVISION:
onProvisionResponse(msg.obj); onProvisionResponse(request, response);
break; break;
case MSG_KEYS: case MSG_KEYS:
onKeyResponse(msg.obj); onKeyResponse(request, response);
break; break;
default: default:
break; break;
@ -483,23 +493,27 @@ import java.util.UUID;
super(backgroundLooper); super(backgroundLooper);
} }
Message obtainMessage(int what, Object object, boolean allowRetry) { void post(int what, Object request, boolean allowRetry) {
return obtainMessage(what, allowRetry ? 1 : 0 /* allow retry*/, 0 /* error count */, int allowRetryInt = allowRetry ? 1 : 0;
object); int errorCount = 0;
obtainMessage(what, allowRetryInt, errorCount, request).sendToTarget();
} }
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void handleMessage(Message msg) { public void handleMessage(Message msg) {
Object request = msg.obj;
Object response; Object response;
try { try {
switch (msg.what) { switch (msg.what) {
case MSG_PROVISION: case MSG_PROVISION:
response = callback.executeProvisionRequest(uuid, (ProvisionRequest) msg.obj); response = callback.executeProvisionRequest(uuid, (ProvisionRequest) request);
break; break;
case MSG_KEYS: case MSG_KEYS:
Pair<KeyRequest, String> arguments = (Pair<KeyRequest, String>) msg.obj; Pair<KeyRequest, String> keyRequest = (Pair<KeyRequest, String>) request;
response = callback.executeKeyRequest(uuid, arguments.first, arguments.second); KeyRequest mediaDrmKeyRequest = keyRequest.first;
String licenseServerUrl = keyRequest.second;
response = callback.executeKeyRequest(uuid, mediaDrmKeyRequest, licenseServerUrl);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
@ -510,7 +524,7 @@ import java.util.UUID;
} }
response = e; response = e;
} }
postResponseHandler.obtainMessage(msg.what, response).sendToTarget(); postResponseHandler.obtainMessage(msg.what, Pair.create(request, response)).sendToTarget();
} }
private boolean maybeRetryRequest(Message originalMsg) { private boolean maybeRetryRequest(Message originalMsg) {