diff --git a/libraries/transformer/src/main/java/androidx/media3/transformer/Muxer.java b/libraries/transformer/src/main/java/androidx/media3/transformer/Muxer.java index be0b0de1ef..cebf94d7eb 100644 --- a/libraries/transformer/src/main/java/androidx/media3/transformer/Muxer.java +++ b/libraries/transformer/src/main/java/androidx/media3/transformer/Muxer.java @@ -80,10 +80,10 @@ public interface Muxer { } /** - * Adds a track with the specified format, and returns its index (to be passed in subsequent calls - * to {@link #writeSampleData(int, ByteBuffer, boolean, long)}). + * Adds a track with the specified format. * * @param format The {@link Format} of the track. + * @return The index for this track, which should be passed to {@link #writeSampleData}. * @throws MuxerException If the muxer encounters a problem while adding the track. */ int addTrack(Format format) throws MuxerException; diff --git a/libraries/transformer/src/main/java/androidx/media3/transformer/MuxerWrapper.java b/libraries/transformer/src/main/java/androidx/media3/transformer/MuxerWrapper.java index 98b09dfc69..0e3e03412d 100644 --- a/libraries/transformer/src/main/java/androidx/media3/transformer/MuxerWrapper.java +++ b/libraries/transformer/src/main/java/androidx/media3/transformer/MuxerWrapper.java @@ -16,15 +16,15 @@ package androidx.media3.transformer; +import static androidx.media3.common.util.Assertions.checkArgument; import static androidx.media3.common.util.Assertions.checkNotNull; import static androidx.media3.common.util.Assertions.checkState; -import static androidx.media3.common.util.Util.maxValue; -import static androidx.media3.common.util.Util.minValue; +import static java.lang.Math.max; +import static java.lang.Math.min; import static java.util.concurrent.TimeUnit.MILLISECONDS; import android.os.ParcelFileDescriptor; -import android.util.SparseIntArray; -import android.util.SparseLongArray; +import android.util.SparseArray; import androidx.annotation.IntRange; import androidx.annotation.Nullable; import androidx.media3.common.C; @@ -49,7 +49,8 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; /* package */ final class MuxerWrapper { public interface Listener { - void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount); + void onTrackEnded( + @C.TrackType int trackType, Format format, int averageBitrate, int sampleCount); void onEnded(long durationMs, long fileSizeBytes); @@ -68,18 +69,15 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; @Nullable private final ParcelFileDescriptor outputParcelFileDescriptor; private final Muxer.Factory muxerFactory; private final Listener listener; - private final SparseIntArray trackTypeToIndex; - private final SparseIntArray trackTypeToSampleCount; - private final SparseLongArray trackTypeToTimeUs; - private final SparseLongArray trackTypeToBytesWritten; + private final SparseArray trackTypeToInfo; private final ScheduledExecutorService abortScheduledExecutorService; private int trackCount; - private int trackFormatCount; private boolean isReady; private boolean isEnded; private @C.TrackType int previousTrackType; private long minTrackTimeUs; + private long maxEndedTrackTimeUs; private @MonotonicNonNull ScheduledFuture abortScheduledFuture; private boolean isAborted; private @MonotonicNonNull Muxer muxer; @@ -98,10 +96,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; this.muxerFactory = muxerFactory; this.listener = listener; - trackTypeToIndex = new SparseIntArray(); - trackTypeToSampleCount = new SparseIntArray(); - trackTypeToTimeUs = new SparseLongArray(); - trackTypeToBytesWritten = new SparseLongArray(); + trackTypeToInfo = new SparseArray<>(); previousTrackType = C.TRACK_TYPE_NONE; abortScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); } @@ -117,7 +112,7 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; */ public void setTrackCount(@IntRange(from = 1) int trackCount) { checkState( - trackFormatCount == 0, + trackTypeToInfo.size() == 0, "The track count cannot be set after track formats have been added."); this.trackCount = trackCount; } @@ -151,25 +146,21 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; */ public void addTrackFormat(Format format) throws Muxer.MuxerException { checkState(trackCount > 0, "The track count should be set before the formats are added."); - checkState(trackFormatCount < trackCount, "All track formats have already been added."); + checkState(trackTypeToInfo.size() < trackCount, "All track formats have already been added."); @Nullable String sampleMimeType = format.sampleMimeType; boolean isAudio = MimeTypes.isAudio(sampleMimeType); boolean isVideo = MimeTypes.isVideo(sampleMimeType); checkState(isAudio || isVideo, "Unsupported track format: " + sampleMimeType); @C.TrackType int trackType = MimeTypes.getTrackType(sampleMimeType); + // SparseArray.get() returns null by default if the value is not found. checkState( - trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET) == C.INDEX_UNSET, - "There is already a track of type " + trackType); + trackTypeToInfo.get(trackType) == null, "There is already a track of type " + trackType); ensureMuxerInitialized(); - int trackIndex = muxer.addTrack(format); - trackTypeToIndex.put(trackType, trackIndex); - trackTypeToSampleCount.put(trackType, 0); - trackTypeToTimeUs.put(trackType, 0L); - trackTypeToBytesWritten.put(trackType, 0L); - trackFormatCount++; - if (trackFormatCount == trackCount) { + TrackInfo trackInfo = new TrackInfo(format, muxer.addTrack(format)); + trackTypeToInfo.put(trackType, trackInfo); + if (trackTypeToInfo.size() == trackCount) { isReady = true; resetAbortTimer(); } @@ -193,25 +184,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; public boolean writeSample( @C.TrackType int trackType, ByteBuffer data, boolean isKeyFrame, long presentationTimeUs) throws Muxer.MuxerException { - int trackIndex = trackTypeToIndex.get(trackType, /* valueIfKeyNotFound= */ C.INDEX_UNSET); - checkState( - trackIndex != C.INDEX_UNSET, - "Could not write sample because there is no track of type " + trackType); + @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType); + // SparseArray.get() returns null by default if the value is not found. + checkArgument( + trackInfo != null, "Could not write sample because there is no track of type " + trackType); if (!canWriteSampleOfType(trackType)) { return false; } - trackTypeToSampleCount.put(trackType, trackTypeToSampleCount.get(trackType) + 1); - trackTypeToBytesWritten.put( - trackType, trackTypeToBytesWritten.get(trackType) + data.remaining()); - if (trackTypeToTimeUs.get(trackType) < presentationTimeUs) { - trackTypeToTimeUs.put(trackType, presentationTimeUs); - } + trackInfo.sampleCount++; + trackInfo.bytesWritten += data.remaining(); + trackInfo.timeUs = max(trackInfo.timeUs, presentationTimeUs); checkNotNull(muxer); resetAbortTimer(); - muxer.writeSampleData(trackIndex, data, isKeyFrame, presentationTimeUs); + muxer.writeSampleData(trackInfo.index, data, isKeyFrame, presentationTimeUs); previousTrackType = trackType; return true; } @@ -223,17 +211,22 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; * @param trackType The {@link C.TrackType track type}. */ public void endTrack(@C.TrackType int trackType) { - listener.onTrackEnded( - trackType, - getTrackAverageBitrate(trackType), - trackTypeToSampleCount.get(trackType, /* valueIfKeyNotFound= */ 0)); + @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType); + if (trackInfo == null) { + // SparseArray.get() returns null by default if the value is not found. + return; + } - trackTypeToIndex.delete(trackType); - if (trackTypeToIndex.size() == 0) { + maxEndedTrackTimeUs = max(maxEndedTrackTimeUs, trackInfo.timeUs); + listener.onTrackEnded( + trackType, trackInfo.format, trackInfo.getAverageBitrate(), trackInfo.sampleCount); + + trackTypeToInfo.delete(trackType); + if (trackTypeToInfo.size() == 0) { abortScheduledExecutorService.shutdownNow(); if (!isEnded) { isEnded = true; - listener.onEnded(getDurationMs(), getCurrentOutputSizeBytes()); + listener.onEnded(Util.usToMs(maxEndedTrackTimeUs), getCurrentOutputSizeBytes()); } } } @@ -273,18 +266,20 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; * track of the given track type. */ private boolean canWriteSampleOfType(int trackType) { - long trackTimeUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ C.TIME_UNSET); - checkState(trackTimeUs != C.TIME_UNSET); + @Nullable TrackInfo trackInfo = trackTypeToInfo.get(trackType); + // SparseArray.get() returns null by default if the value is not found. + checkArgument(trackInfo != null, "There is no track of type " + trackType); + if (!isReady) { return false; } - if (trackTypeToIndex.size() == 1) { + if (trackTypeToInfo.size() == 1) { return true; } if (trackType != previousTrackType) { - minTrackTimeUs = minValue(trackTypeToTimeUs); + minTrackTimeUs = getMinTrackTimeUs(trackTypeToInfo); } - return trackTimeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US; + return trackInfo.timeUs - minTrackTimeUs <= MAX_TRACK_WRITE_AHEAD_US; } @RequiresNonNull("muxer") @@ -327,17 +322,6 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; } } - /** - * Returns the duration of the longest track in milliseconds, or {@link C#TIME_UNSET} if there is - * no track. - */ - private long getDurationMs() { - if (trackTypeToTimeUs.size() == 0) { - return C.TIME_UNSET; - } - return Util.usToMs(maxValue(trackTypeToTimeUs)); - } - /** Returns the current size in bytes of the output, or {@link C#LENGTH_UNSET} if unavailable. */ private long getCurrentOutputSizeBytes() { long fileSize = C.LENGTH_UNSET; @@ -351,22 +335,47 @@ import org.checkerframework.checker.nullness.qual.RequiresNonNull; return fileSize > 0 ? fileSize : C.LENGTH_UNSET; } - /** - * Returns the average bitrate of data written to the track of the provided {@code trackType}, or - * {@link C#RATE_UNSET_INT} if there is no track data. - */ - private int getTrackAverageBitrate(@C.TrackType int trackType) { - long trackDurationUs = trackTypeToTimeUs.get(trackType, /* valueIfKeyNotFound= */ -1); - long trackBytes = trackTypeToBytesWritten.get(trackType, /* valueIfKeyNotFound= */ -1); - if (trackDurationUs <= 0 || trackBytes <= 0) { - return C.RATE_UNSET_INT; + private static long getMinTrackTimeUs(SparseArray trackTypeToInfo) { + if (trackTypeToInfo.size() == 0) { + return C.TIME_UNSET; + } + + long minTrackTimeUs = Long.MAX_VALUE; + for (int i = 0; i < trackTypeToInfo.size(); i++) { + minTrackTimeUs = min(minTrackTimeUs, trackTypeToInfo.valueAt(i).timeUs); + } + return minTrackTimeUs; + } + + private static final class TrackInfo { + public final Format format; + public final int index; + + public long bytesWritten; + public int sampleCount; + public long timeUs; + + public TrackInfo(Format format, int index) { + this.format = format; + this.index = index; + } + + /** + * Returns the average bitrate of data written to the track, or {@link C#RATE_UNSET_INT} if + * there is no track data. + */ + public int getAverageBitrate() { + if (timeUs <= 0 || bytesWritten <= 0) { + return C.RATE_UNSET_INT; + } + + // The number of bytes written is not a timestamp, however this utility method provides + // overflow-safe multiplication & division. + return (int) + Util.scaleLargeTimestamp( + /* timestamp= */ bytesWritten, + /* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND, + /* divisor= */ timeUs); } - // The number of bytes written is not a timestamp, however this utility method provides - // overflow-safe multiplication & division. - return (int) - Util.scaleLargeTimestamp( - /* timestamp= */ trackBytes, - /* multiplier= */ C.BITS_PER_BYTE * C.MICROS_PER_SECOND, - /* divisor= */ trackDurationUs); } } diff --git a/libraries/transformer/src/main/java/androidx/media3/transformer/TransformerInternal.java b/libraries/transformer/src/main/java/androidx/media3/transformer/TransformerInternal.java index 10e8b5ea5f..2bd8705140 100644 --- a/libraries/transformer/src/main/java/androidx/media3/transformer/TransformerInternal.java +++ b/libraries/transformer/src/main/java/androidx/media3/transformer/TransformerInternal.java @@ -465,7 +465,8 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull; // MuxerWrapper.Listener implementation. @Override - public void onTrackEnded(@C.TrackType int trackType, int averageBitrate, int sampleCount) { + public void onTrackEnded( + @C.TrackType int trackType, Format format, int averageBitrate, int sampleCount) { if (trackType == C.TRACK_TYPE_AUDIO) { transformationResultBuilder.setAverageAudioBitrate(averageBitrate); } else if (trackType == C.TRACK_TYPE_VIDEO) {