Refactor MuxerWrapper handling of track details.

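This brings the details of each muxer track (format, index, sample count,
bytes written and time) together in a single TrackInfo object, and
reduces the need for additional variables (the separate per-field
collections and counters) and for more complicated track bookkeeping.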

PiperOrigin-RevId: 499872145
This commit is contained in:
samrobinson 2023-01-05 15:37:39 +00:00 committed by christosts
parent a59c2b8222
commit f32b632b09
3 changed files with 88 additions and 78 deletions

View File

@ -80,10 +80,10 @@ public interface Muxer {
} }
/** /**
* Adds a track with the specified format, and returns its index (to be passed in subsequent calls * Adds a track with the specified format.
* to {@link #writeSampleData(int, ByteBuffer, boolean, long)}).
* *
* @param format The {@link Format} of the track. * @param format The {@link Format} of the track.
* @return The index for this track, which should be passed to {@link #writeSampleData}.
* @throws MuxerException If the muxer encounters a problem while adding the track. * @throws MuxerException If the muxer encounters a problem while adding the track.
*/ */
int addTrack(Format format) throws MuxerException; int addTrack(Format format) throws MuxerException;

View File

@ -16,15 +16,15 @@
package androidx.media3.transformer; package androidx.media3.transformer;
import static androidx.media3.common.util.Assertions.checkArgument;
import static androidx.media3.common.util.Assertions.checkNotNull; import static androidx.media3.common.util.Assertions.checkNotNull;
import static androidx.media3.common.util.Assertions.checkState; import static androidx.media3.common.util.Assertions.checkState;
import static androidx.media3.common.util.Util.maxValue; import static java.lang.Math.max;
import static androidx.media3.common.util.Util.minValue; import static java.lang.Math.min;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
import android.os.ParcelFileDescriptor; import android.os.ParcelFileDescriptor;
import android.util.SparseIntArray; import android.util.SparseArray;
import android.util.SparseLongArray;
import androidx.annotation.IntRange; import androidx.annotation.IntRange;
import androidx.annotation.Nullable; import androidx.annotation.Nullable;
import androidx.media3.common.C; import androidx.media3.common.C;
@ -49,7 +49,8 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
/* package */ final class MuxerWrapper { /* package */ final class MuxerWrapper {
public interface Listener { public interface Listener {
void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount); void onTrackEnded(
@C.TrackType int trackType, Format format, int averageBitrate, int sampleCount);
void onEnded(long durationMs, long fileSizeBytes); void onEnded(long durationMs, long fileSizeBytes);
@ -68,18 +69,15 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
@Nullable private final ParcelFileDescriptor outputParcelFileDescriptor; @Nullable private final ParcelFileDescriptor outputParcelFileDescriptor;
private final Muxer.Factory muxerFactory; private final Muxer.Factory muxerFactory;
private final Listener listener; private final Listener listener;
private final SparseIntArray trackTypeToIndex; private final SparseArray<TrackInfo> trackTypeToInfo;
private final SparseIntArray trackTypeToSampleCount;
private final SparseLongArray trackTypeToTimeUs;
private final SparseLongArray trackTypeToBytesWritten;
private final ScheduledExecutorService abortScheduledExecutorService; private final ScheduledExecutorService abortScheduledExecutorService;
private int trackCount; private int trackCount;
private int trackFormatCount;
private boolean isReady; private boolean isReady;
private boolean isEnded; private boolean isEnded;
private @C.TrackType int previousTrackType; private @C.TrackType int previousTrackType;
private long minTrackTimeUs; private long minTrackTimeUs;
private long maxEndedTrackTimeUs;
private @MonotonicNonNull ScheduledFuture<?> abortScheduledFuture; private @MonotonicNonNull ScheduledFuture<?> abortScheduledFuture;
private boolean isAborted; private boolean isAborted;
private @MonotonicNonNull Muxer muxer; private @MonotonicNonNull Muxer muxer;
@ -98,10 +96,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
this.muxerFactory = muxerFactory; this.muxerFactory = muxerFactory;
this.listener = listener; this.listener = listener;
trackTypeToIndex = new SparseIntArray(); trackTypeToInfo = new SparseArray<>();
trackTypeToSampleCount = new SparseIntArray();
trackTypeToTimeUs = new SparseLongArray();
trackTypeToBytesWritten = new SparseLongArray();
previousTrackType = C.TRACK_TYPE_NONE; previousTrackType = C.TRACK_TYPE_NONE;
abortScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); abortScheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
} }
@ -117,7 +112,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
*/ */
public void setTrackCount(@IntRange(from = 1) int trackCount) { public void setTrackCount(@IntRange(from = 1) int trackCount) {
checkState( checkState(
trackFormatCount == 0, trackTypeToInfo.size() == 0,
"The track count cannot be set after track formats have been added."); "The track count cannot be set after track formats have been added.");
this.trackCount = trackCount; this.trackCount = trackCount;
} }
@ -151,25 +146,21 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
*/ */
public void addTrackFormat(Format format) throws Muxer.MuxerException { public void addTrackFormat(Format format) throws Muxer.MuxerException {
checkState(trackCount > 0, "The track count should be set before the formats are added."); checkState(trackCount > 0, "The track count should be set before the formats are added.");
checkState(trackFormatCount < trackCount, "All track formats have already been added."); checkState(trackTypeToInfo.size() < trackCount, "All track formats have already been added.");
@Nullable String sampleMimeType = format.sampleMimeType; @Nullable String sampleMimeType = format.sampleMimeType;
boolean isAudio = MimeTypes.isAudio(sampleMimeType); boolean isAudio = MimeTypes.isAudio(sampleMimeType);
boolean isVideo = MimeTypes.isVideo(sampleMimeType); boolean isVideo = MimeTypes.isVideo(sampleMimeType);
checkState(isAudio || isVideo, "Unsupported track format: " + sampleMimeType); checkState(isAudio || isVideo, "Unsupported track format: " + sampleMimeType);
@C.TrackType int trackType = MimeTypes.getTrackType(sampleMimeType); @C.TrackType int trackType = MimeTypes.getTrackType(sampleMimeType);
// SparseArray.get() returns null by default if the value is not found.
checkState( checkState(
trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET) == C.INDEX_UNSET, trackTypeToInfo.get(trackType) == null, "There is already a track of type " + trackType);
"There is already a track of type " + trackType);
ensureMuxerInitialized(); ensureMuxerInitialized();
int trackIndex = muxer.addTrack(format); TrackInfo trackInfo = new TrackInfo(format, muxer.addTrack(format));
trackTypeToIndex.put(trackType, trackIndex); trackTypeToInfo.put(trackType, trackInfo);
trackTypeToSampleCount.put(trackType, 0); if (trackTypeToInfo.size() == trackCount) {
trackTypeToTimeUs.put(trackType, 0L);
trackTypeToBytesWritten.put(trackType, 0L);
trackFormatCount++;
if (trackFormatCount == trackCount) {
isReady = true; isReady = true;
resetAbortTimer(); resetAbortTimer();
} }
@ -193,25 +184,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
public boolean writeSample( public boolean writeSample(
@C.TrackType int trackType, ByteBuffer data, boolean isKeyFrame, long presentationTimeUs) @C.TrackType int trackType, ByteBuffer data, boolean isKeyFrame, long presentationTimeUs)
throws Muxer.MuxerException { throws Muxer.MuxerException {
int trackIndex = trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET); @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
checkState( // SparseArray.get() returns null by default if the value is not found.
trackIndex != C.INDEX_UNSET, checkArgument(
"Could not write sample because there is no track of type " + trackType); trackInfo != null, "Could not write sample because there is no track of type " + trackType);
if (!canWriteSampleOfType(trackType)) { if (!canWriteSampleOfType(trackType)) {
return false; return false;
} }
trackTypeToSampleCount.put(trackType, trackTypeToSampleCount.get(trackType) + 1); trackInfo.sampleCount++;
trackTypeToBytesWritten.put( trackInfo.bytesWritten += data.remaining();
trackType, trackTypeToBytesWritten.get(trackType) + data.remaining()); trackInfo.timeUs = max(trackInfo.timeUs, presentationTimeUs);
if (trackTypeToTimeUs.get(trackType) < presentationTimeUs) {
trackTypeToTimeUs.put(trackType, presentationTimeUs);
}
checkNotNull(muxer); checkNotNull(muxer);
resetAbortTimer(); resetAbortTimer();
muxer.writeSampleData(trackIndex, data, isKeyFrame, presentationTimeUs); muxer.writeSampleData(trackInfo.index, data, isKeyFrame, presentationTimeUs);
previousTrackType = trackType; previousTrackType = trackType;
return true; return true;
} }
@ -223,17 +211,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
* @param trackType The {@link C.TrackType track type}. * @param trackType The {@link C.TrackType track type}.
*/ */
public void endTrack(@C.TrackType int trackType) { public void endTrack(@C.TrackType int trackType) {
listener.onTrackEnded( @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
trackType, if (trackInfo == null) {
getTrackAverageBitrate(trackType), // SparseArray.get() returns null by default if the value is not found.
trackTypeToSampleCount.get(trackType, /* valueIfKeyNotFound= */ 0)); return;
}
trackTypeToIndex.delete(trackType); maxEndedTrackTimeUs = max(maxEndedTrackTimeUs, trackInfo.timeUs);
if (trackTypeToIndex.size() == 0) { listener.onTrackEnded(
trackType, trackInfo.format, trackInfo.getAverageBitrate(), trackInfo.sampleCount);
trackTypeToInfo.delete(trackType);
if (trackTypeToInfo.size() == 0) {
abortScheduledExecutorService.shutdownNow(); abortScheduledExecutorService.shutdownNow();
if (!isEnded) { if (!isEnded) {
isEnded = true; isEnded = true;
listener.onEnded(getDurationMs(), getCurrentOutputSizeBytes()); listener.onEnded(Util.usToMs(maxEndedTrackTimeUs), getCurrentOutputSizeBytes());
} }
} }
} }
@ -273,18 +266,20 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
* track of the given track type. * track of the given track type.
*/ */
private boolean canWriteSampleOfType(int trackType) { private boolean canWriteSampleOfType(int trackType) {
long trackTimeUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ C.TIME_UNSET); @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
checkState(trackTimeUs != C.TIME_UNSET); // SparseArray.get() returns null by default if the value is not found.
checkArgument(trackInfo != null, "There is no track of type " + trackType);
if (!isReady) { if (!isReady) {
return false; return false;
} }
if (trackTypeToIndex.size() == 1) { if (trackTypeToInfo.size() == 1) {
return true; return true;
} }
if (trackType != previousTrackType) { if (trackType != previousTrackType) {
minTrackTimeUs = minValue(trackTypeToTimeUs); minTrackTimeUs = getMinTrackTimeUs(trackTypeToInfo);
} }
return trackTimeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US; return trackInfo.timeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US;
} }
@RequiresNonNull("muxer") @RequiresNonNull("muxer")
@ -327,17 +322,6 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
} }
} }
/**
* Returns the duration of the longest track in milliseconds, or {@link C#TIME_UNSET} if there is
* no track.
*/
private long getDurationMs() {
if (trackTypeToTimeUs.size() == 0) {
return C.TIME_UNSET;
}
return Util.usToMs(maxValue(trackTypeToTimeUs));
}
/** Returns the current size in bytes of the output, or {@link C#LENGTH_UNSET} if unavailable. */ /** Returns the current size in bytes of the output, or {@link C#LENGTH_UNSET} if unavailable. */
private long getCurrentOutputSizeBytes() { private long getCurrentOutputSizeBytes() {
long fileSize = C.LENGTH_UNSET; long fileSize = C.LENGTH_UNSET;
@ -351,22 +335,47 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
return fileSize > 0 ? fileSize : C.LENGTH_UNSET; return fileSize > 0 ? fileSize : C.LENGTH_UNSET;
} }
/**
 * Returns the smallest {@code timeUs} among the tracks in {@code trackTypeToInfo}, or {@link
 * C#TIME_UNSET} if it holds no tracks.
 */
private static long getMinTrackTimeUs(SparseArray<TrackInfo> trackTypeToInfo) {
  int trackInfoCount = trackTypeToInfo.size();
  if (trackInfoCount == 0) {
    return C.TIME_UNSET;
  }
  // Seed with the first track's time, then fold the remaining tracks into the running minimum.
  long smallestTimeUs = trackTypeToInfo.valueAt(0).timeUs;
  for (int position = 1; position < trackInfoCount; position++) {
    smallestTimeUs = min(smallestTimeUs, trackTypeToInfo.valueAt(position).timeUs);
  }
  return smallestTimeUs;
}
private static final class TrackInfo {
public final Format format;
public final int index;
public long bytesWritten;
public int sampleCount;
public long timeUs;
public TrackInfo(Format format, int index) {
this.format = format;
this.index = index;
}
/** /**
* Returns the average bitrate of data written to the track of the provided {@code trackType}, or * Returns the average bitrate of data written to the track, or {@link C#RATE_UNSET_INT} if
* {@link C#RATE_UNSET_INT} if there is no track data. * there is no track data.
*/ */
private int getTrackAverageBitrate(@C.TrackType int trackType) { public int getAverageBitrate() {
long trackDurationUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ -1); if (timeUs <= 0 || bytesWritten <= 0) {
long trackBytes = trackTypeToBytesWritten.get(trackType, /* valueIfKeyNotFound= */ -1);
if (trackDurationUs <= 0 || trackBytes <= 0) {
return C.RATE_UNSET_INT; return C.RATE_UNSET_INT;
} }
// The number of bytes written is not a timestamp, however this utility method provides // The number of bytes written is not a timestamp, however this utility method provides
// overflow-safe multiplication & division. // overflow-safe multiplication & division.
return (int) return (int)
Util.scaleLargeTimestamp( Util.scaleLargeTimestamp(
/* timestamp= */ trackBytes, /* timestamp= */ bytesWritten,
/* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND, /* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND,
/* divisor= */ trackDurationUs); /* divisor= */ timeUs);
}
} }
} }

View File

@ -465,7 +465,8 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
// MuxerWrapper.Listener implementation. // MuxerWrapper.Listener implementation.
@Override @Override
public void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount) { public void onTrackEnded(
@C.TrackType int trackType, Format format, int averageBitrate, int sampleCount) {
if (trackType == C.TRACK_TYPE_AUDIO) { if (trackType == C.TRACK_TYPE_AUDIO) {
transformationResultBuilder.setAverageAudioBitrate(averageBitrate); transformationResultBuilder.setAverageAudioBitrate(averageBitrate);
} else if (trackType == C.TRACK_TYPE_VIDEO) { } else if (trackType == C.TRACK_TYPE_VIDEO) {