Refactor MuxerWrapper handling of track details.

This brings together the multiple details about a muxer track, and
reduces the need for additional variables and more complicated track
tracking.

PiperOrigin-RevId: 499872145
This commit is contained in:
samrobinson 2023-01-05 15:37:39 +00:00 committed by christosts
parent a59c2b8222
commit f32b632b09
3 changed files with 88 additions and 78 deletions

View File

@ -80,10 +80,10 @@ public interface Muxer {
}
/**
* Adds a track with the specified format, and returns its index (to be passed in subsequent calls
* to {@link #writeSampleData(int, ByteBuffer, boolean, long)}).
* Adds a track with the specified format.
*
* @param format The {@link Format} of the track.
* @return The index for this track, which should be passed to {@link #writeSampleData}.
* @throws MuxerException If the muxer encounters a problem while adding the track.
*/
int addTrack(Format format) throws MuxerException;

View File

@ -16,15 +16,15 @@
package androidx.media3.transformer;
import static androidx.media3.common.util.Assertions.checkArgument;
import static androidx.media3.common.util.Assertions.checkNotNull;
import static androidx.media3.common.util.Assertions.checkState;
import static androidx.media3.common.util.Util.maxValue;
import static androidx.media3.common.util.Util.minValue;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import android.os.ParcelFileDescriptor;
import android.util.SparseIntArray;
import android.util.SparseLongArray;
import android.util.SparseArray;
import androidx.annotation.IntRange;
import androidx.annotation.Nullable;
import androidx.media3.common.C;
@ -49,7 +49,8 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
/* package */ final class MuxerWrapper {
public interface Listener {
void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount);
void onTrackEnded(
@C.TrackType int trackType, Format format, int averageBitrate, int sampleCount);
void onEnded(long durationMs, long fileSizeBytes);
@ -68,18 +69,15 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
@Nullable private final ParcelFileDescriptor outputParcelFileDescriptor;
private final Muxer.Factory muxerFactory;
private final Listener listener;
private final SparseIntArray trackTypeToIndex;
private final SparseIntArray trackTypeToSampleCount;
private final SparseLongArray trackTypeToTimeUs;
private final SparseLongArray trackTypeToBytesWritten;
private final SparseArray<TrackInfo> trackTypeToInfo;
private final ScheduledExecutorService abortScheduledExecutorService;
private int trackCount;
private int trackFormatCount;
private boolean isReady;
private boolean isEnded;
private @C.TrackType int previousTrackType;
private long minTrackTimeUs;
private long maxEndedTrackTimeUs;
private @MonotonicNonNull ScheduledFuture<?> abortScheduledFuture;
private boolean isAborted;
private @MonotonicNonNull Muxer muxer;
@ -98,10 +96,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
this.muxerFactory = muxerFactory;
this.listener = listener;
trackTypeToIndex = new SparseIntArray();
trackTypeToSampleCount = new SparseIntArray();
trackTypeToTimeUs = new SparseLongArray();
trackTypeToBytesWritten = new SparseLongArray();
trackTypeToInfo = new SparseArray<>();
previousTrackType = C.TRACK_TYPE_NONE;
abortScheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
}
@ -117,7 +112,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
*/
public void setTrackCount(@IntRange(from = 1) int trackCount) {
checkState(
trackFormatCount == 0,
trackTypeToInfo.size() == 0,
"The track count cannot be set after track formats have been added.");
this.trackCount = trackCount;
}
@ -151,25 +146,21 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
*/
public void addTrackFormat(Format format) throws Muxer.MuxerException {
checkState(trackCount > 0, "The track count should be set before the formats are added.");
checkState(trackFormatCount < trackCount, "All track formats have already been added.");
checkState(trackTypeToInfo.size() < trackCount, "All track formats have already been added.");
@Nullable String sampleMimeType = format.sampleMimeType;
boolean isAudio = MimeTypes.isAudio(sampleMimeType);
boolean isVideo = MimeTypes.isVideo(sampleMimeType);
checkState(isAudio || isVideo, "Unsupported track format: " + sampleMimeType);
@C.TrackType int trackType = MimeTypes.getTrackType(sampleMimeType);
// SparseArray.get() returns null by default if the value is not found.
checkState(
trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET) == C.INDEX_UNSET,
"There is already a track of type " + trackType);
trackTypeToInfo.get(trackType) == null, "There is already a track of type " + trackType);
ensureMuxerInitialized();
int trackIndex = muxer.addTrack(format);
trackTypeToIndex.put(trackType, trackIndex);
trackTypeToSampleCount.put(trackType, 0);
trackTypeToTimeUs.put(trackType, 0L);
trackTypeToBytesWritten.put(trackType, 0L);
trackFormatCount++;
if (trackFormatCount == trackCount) {
TrackInfo trackInfo = new TrackInfo(format, muxer.addTrack(format));
trackTypeToInfo.put(trackType, trackInfo);
if (trackTypeToInfo.size() == trackCount) {
isReady = true;
resetAbortTimer();
}
@ -193,25 +184,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
public boolean writeSample(
@C.TrackType int trackType, ByteBuffer data, boolean isKeyFrame, long presentationTimeUs)
throws Muxer.MuxerException {
int trackIndex = trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET);
checkState(
trackIndex != C.INDEX_UNSET,
"Could not write sample because there is no track of type " + trackType);
@Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
// SparseArray.get() returns null by default if the value is not found.
checkArgument(
trackInfo != null, "Could not write sample because there is no track of type " + trackType);
if (!canWriteSampleOfType(trackType)) {
return false;
}
trackTypeToSampleCount.put(trackType, trackTypeToSampleCount.get(trackType) + 1);
trackTypeToBytesWritten.put(
trackType, trackTypeToBytesWritten.get(trackType) + data.remaining());
if (trackTypeToTimeUs.get(trackType) < presentationTimeUs) {
trackTypeToTimeUs.put(trackType, presentationTimeUs);
}
trackInfo.sampleCount++;
trackInfo.bytesWritten += data.remaining();
trackInfo.timeUs = max(trackInfo.timeUs, presentationTimeUs);
checkNotNull(muxer);
resetAbortTimer();
muxer.writeSampleData(trackIndex, data, isKeyFrame, presentationTimeUs);
muxer.writeSampleData(trackInfo.index, data, isKeyFrame, presentationTimeUs);
previousTrackType = trackType;
return true;
}
@ -223,17 +211,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
* @param trackType The {@link C.TrackType track type}.
*/
public void endTrack(@C.TrackType int trackType) {
listener.onTrackEnded(
trackType,
getTrackAverageBitrate(trackType),
trackTypeToSampleCount.get(trackType, /* valueIfKeyNotFound= */ 0));
@Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
if (trackInfo == null) {
// SparseArray.get() returns null by default if the value is not found.
return;
}
trackTypeToIndex.delete(trackType);
if (trackTypeToIndex.size() == 0) {
maxEndedTrackTimeUs = max(maxEndedTrackTimeUs, trackInfo.timeUs);
listener.onTrackEnded(
trackType, trackInfo.format, trackInfo.getAverageBitrate(), trackInfo.sampleCount);
trackTypeToInfo.delete(trackType);
if (trackTypeToInfo.size() == 0) {
abortScheduledExecutorService.shutdownNow();
if (!isEnded) {
isEnded = true;
listener.onEnded(getDurationMs(), getCurrentOutputSizeBytes());
listener.onEnded(Util.usToMs(maxEndedTrackTimeUs), getCurrentOutputSizeBytes());
}
}
}
@ -273,18 +266,20 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
* track of the given track type.
*/
private boolean canWriteSampleOfType(int trackType) {
long trackTimeUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ C.TIME_UNSET);
checkState(trackTimeUs != C.TIME_UNSET);
@Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType);
// SparseArray.get() returns null by default if the value is not found.
checkArgument(trackInfo != null, "There is no track of type " + trackType);
if (!isReady) {
return false;
}
if (trackTypeToIndex.size() == 1) {
if (trackTypeToInfo.size() == 1) {
return true;
}
if (trackType != previousTrackType) {
minTrackTimeUs = minValue(trackTypeToTimeUs);
minTrackTimeUs = getMinTrackTimeUs(trackTypeToInfo);
}
return trackTimeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US;
return trackInfo.timeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US;
}
@RequiresNonNull("muxer")
@ -327,17 +322,6 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
}
}
/**
* Returns the duration of the longest track in milliseconds, or {@link C#TIME_UNSET} if there is
* no track.
*/
private long getDurationMs() {
if (trackTypeToTimeUs.size() == 0) {
return C.TIME_UNSET;
}
return Util.usToMs(maxValue(trackTypeToTimeUs));
}
/** Returns the current size in bytes of the output, or {@link C#LENGTH_UNSET} if unavailable. */
private long getCurrentOutputSizeBytes() {
long fileSize = C.LENGTH_UNSET;
@ -351,22 +335,47 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull;
return fileSize > 0 ? fileSize : C.LENGTH_UNSET;
}
/**
* Returns the average bitrate of data written to the track of the provided {@code trackType}, or
* {@link C#RATE_UNSET_INT} if there is no track data.
*/
private int getTrackAverageBitrate(@C.TrackType int trackType) {
long trackDurationUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ -1);
long trackBytes = trackTypeToBytesWritten.get(trackType, /* valueIfKeyNotFound= */ -1);
if (trackDurationUs <= 0 || trackBytes <= 0) {
return C.RATE_UNSET_INT;
private static long getMinTrackTimeUs(SparseArray<TrackInfo> trackTypeToInfo) {
if (trackTypeToInfo.size() == 0) {
return C.TIME_UNSET;
}
long minTrackTimeUs = Long.MAX_VALUE;
for (int i = 0; i < trackTypeToInfo.size(); i++) {
minTrackTimeUs = min(minTrackTimeUs, trackTypeToInfo.valueAt(i).timeUs);
}
return minTrackTimeUs;
}
private static final class TrackInfo {
public final Format format;
public final int index;
public long bytesWritten;
public int sampleCount;
public long timeUs;
public TrackInfo(Format format, int index) {
this.format = format;
this.index = index;
}
/**
* Returns the average bitrate of data written to the track, or {@link C#RATE_UNSET_INT} if
* there is no track data.
*/
public int getAverageBitrate() {
if (timeUs <= 0 || bytesWritten <= 0) {
return C.RATE_UNSET_INT;
}
// The number of bytes written is not a timestamp, however this utility method provides
// overflow-safe multiplication & division.
return (int)
Util.scaleLargeTimestamp(
/* timestamp= */ bytesWritten,
/* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND,
/* divisor= */ timeUs);
}
// The number of bytes written is not a timestamp, however this utility method provides
// overflow-safe multiplication & division.
return (int)
Util.scaleLargeTimestamp(
/* timestamp= */ trackBytes,
/* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND,
/* divisor= */ trackDurationUs);
}
}

View File

@ -465,7 +465,8 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
// MuxerWrapper.Listener implementation.
@Override
public void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount) {
public void onTrackEnded(
@C.TrackType int trackType, Format format, int averageBitrate, int sampleCount) {
if (trackType == C.TRACK_TYPE_AUDIO) {
transformationResultBuilder.setAverageAudioBitrate(averageBitrate);
} else if (trackType == C.TRACK_TYPE_VIDEO) {