diff --git a/libraries/muxer/src/main/java/androidx/media3/muxer/AnnexBUtils.java b/libraries/muxer/src/main/java/androidx/media3/muxer/AnnexBUtils.java index 85fc85e105..d1905272de 100644 --- a/libraries/muxer/src/main/java/androidx/media3/muxer/AnnexBUtils.java +++ b/libraries/muxer/src/main/java/androidx/media3/muxer/AnnexBUtils.java @@ -15,7 +15,8 @@ */ package androidx.media3.muxer; -import androidx.media3.common.C; +import static androidx.media3.common.util.Assertions.checkState; + import androidx.media3.common.MimeTypes; import com.google.common.collect.ImmutableList; import java.nio.ByteBuffer; @@ -40,64 +41,34 @@ import java.nio.ByteBuffer; return ImmutableList.of(); } - int nalStartIndex = C.INDEX_UNSET; - int inputLimit = input.limit(); - boolean readingNalUnit = false; + // The algorithm always searches for 0x000001 start code but it will work for 0x00000001 start + // code as well because the first 0 will be considered as a leading 0 and will be skipped. - // The input must start with a NAL unit. - for (int i = 0; i < inputLimit; i++) { - if (isThreeByteNalStartCode(input, i)) { - nalStartIndex = i + THREE_BYTE_NAL_START_CODE_SIZE; - readingNalUnit = true; - break; - } else if (input.get(i) == 0) { - // Skip the leading zeroes. - } else { - throw new IllegalStateException("Sample does not start with a NAL unit"); - } - } + int nalStartCodeIndex = skipLeadingZerosAndFindNalStartCodeIndex(input, /* currentIndex= */ 0); + + int nalStartIndex = nalStartCodeIndex + THREE_BYTE_NAL_START_CODE_SIZE; + boolean readingNalUnit = true; ImmutableList.Builder nalUnits = new ImmutableList.Builder<>(); - // Look for start code 0x000001. The logic will work for 0x00000001 start code as well because a - // NAL unit gets ended even when 0x000000 (which is a prefix of 0x00000001 start code) is found. - for (int i = nalStartIndex; i < inputLimit; ) { + int i = nalStartIndex; + while (i < input.limit()) { if (readingNalUnit) { - // Found next start code 0x000001. - if (isThreeByteNalStartCode(input, i)) { - nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex)); - i = i + THREE_BYTE_NAL_START_CODE_SIZE; - nalStartIndex = i; - continue; - } else if (isThreeBytesZeroSequence(input, i)) { - // Found code 0x000000; The previous NAL unit should be ended. - nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex)); - // Stop reading NAL unit until next start code is found. - readingNalUnit = false; - i++; - } else { - // Continue reading NAL unit. - i++; - } + int nalEndIndex = findNalEndIndex(input, i); + nalUnits.add(getBytes(input, nalStartIndex, nalEndIndex - nalStartIndex)); + i = nalEndIndex; + readingNalUnit = false; } else { - // Found new start code 0x000001. - if (isThreeByteNalStartCode(input, i)) { - i = i + THREE_BYTE_NAL_START_CODE_SIZE; - nalStartIndex = i; + int nextNalStartCodeIndex = skipLeadingZerosAndFindNalStartCodeIndex(input, i); + if (nextNalStartCodeIndex != input.limit()) { + nalStartIndex = nextNalStartCodeIndex + THREE_BYTE_NAL_START_CODE_SIZE; + i = nalStartIndex; readingNalUnit = true; - } else if (input.get(i) == 0x00) { - // Skip trailing zeroes. - i++; } else { - // Found garbage data. - throw new IllegalStateException("Invalid NAL units"); + break; } } - - // Add the last NAL unit. - if (i == inputLimit && readingNalUnit) { - nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex)); - } } + input.rewind(); return nalUnits.build(); } @@ -137,18 +108,107 @@ import java.nio.ByteBuffer; || sampleMimeType.equals(MimeTypes.VIDEO_H265); } - private static boolean isThreeByteNalStartCode(ByteBuffer input, int currentIndex) { - return (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE - && input.get(currentIndex) == 0x00 - && input.get(currentIndex + 1) == 0x00 - && input.get(currentIndex + 2) == 0x01); + /** + * Returns the end position (exclusive) of the current NAL unit within the input. + * + *

A NAL unit is terminated by one of the following sequences: + * + *

+ * + * @param input The {@link ByteBuffer} containing NAL units. + * @param currentIndex The starting position for the search. + * @return The NAL unit end index (exclusive). + */ + private static int findNalEndIndex(ByteBuffer input, int currentIndex) { + while (currentIndex <= input.limit() - 4) { + int fourBytes = input.getInt(currentIndex); + // Check if the first 3 bytes are 0x000000 or 0x000001. + if ((fourBytes & 0xFFFFFF00) == 0 || (fourBytes & 0xFFFFFF00) == 0x00000100) { + return currentIndex; + } + + // Check if the last 3 bytes are 0x000000 or 0x000001. + if ((fourBytes & 0x00FFFFFF) == 0 || (fourBytes & 0x00FFFFFF) == 0x00000001) { + return currentIndex + 1; + } + + // Check if the last 2 bytes are prefix of 0x000000 or 0x000001. + if ((fourBytes & 0x0000FFFF) == 0) { + currentIndex = currentIndex + 2; + } else if ((fourBytes & 0x000000FF) + == 0) { // Check if the last byte is prefix of 0x000000 or 0x000001. + currentIndex = currentIndex + 3; + } else { + currentIndex = currentIndex + 4; + } + } + + // Handle remaining bytes if any (less than 4). + // Last 3 bytes could be 0x000000 or 0x000001. + if (currentIndex == input.limit() - THREE_BYTE_NAL_START_CODE_SIZE) { + short firstTwoBytes = input.getShort(currentIndex); + byte lastByte = input.get(currentIndex + 2); + if (firstTwoBytes == 0 && (lastByte == 0 || lastByte == 1)) { + return currentIndex; + } + } + return input.limit(); } - private static boolean isThreeBytesZeroSequence(ByteBuffer input, int currentIndex) { - return (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE - && input.get(currentIndex) == 0x00 - && input.get(currentIndex + 1) == 0x00 - && input.get(currentIndex + 2) == 0x00); + /** + * Skips leading zeros and locates the start of the next NAL unit (0x000001). + * + * @param input The {@link ByteBuffer} containing NAL units. + * @param currentIndex The starting position for the search. + * @return The index of the NAL start code, or the end of the input if NAL start code is not + * found. + */ + private static int skipLeadingZerosAndFindNalStartCodeIndex(ByteBuffer input, int currentIndex) { + while (currentIndex <= input.limit() - 4) { + int fourBytes = input.getInt(currentIndex); + + // Check if the first 3 bytes is 0x000001. + if ((fourBytes & 0xFFFFFF00) == 0x00000100) { + return currentIndex; + } + + // Otherwise the first 3 bytes must be 0. + checkState((fourBytes & 0xFFFFFF00) == 0, "Invalid Nal units"); + + // Check if the last byte is 1. It then makes last three bytes 0x000001. + if ((fourBytes & 0x000000FF) == 1) { + return currentIndex + 1; + } + + // Otherwise the last byte must be 0; + checkState((fourBytes & 0x000000FF) == 0, "Invalid Nal units"); + + // Last three zeroes can be a prefix of the NAL start code 0x000001. + currentIndex = currentIndex + 1; + } + + // Handle remaining bytes if any (less than 4). + // Last 3 bytes could be 0x000001. + if (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE) { + short firstTwoBytes = input.getShort(currentIndex); + checkState(firstTwoBytes == 0, "Invalid NAL units"); + byte lastByte = input.get(currentIndex + 2); + if (lastByte == 1) { + return currentIndex; + } + checkState(lastByte == 0, "Invalid NAL units"); + } else { + // Remaining bytes must be 0. + while (currentIndex < input.limit()) { + checkState(input.get(currentIndex) == 0, "Invalid NAL units"); + currentIndex++; + } + } + return input.limit(); } private static ByteBuffer getBytes(ByteBuffer buf, int offset, int length) {