Read NAL unit data as 4 byte integer

When converting NAL units from AnnexB to Avcc format,
one byte at a time was read. In fact many bytes were read
multiple times due to suboptimal logic.

Changed the logic to read 4 bytes at once and also to avoid
reading same bytes again.

This improved the time taken for writing a batch of 30
samples from 40ms to 20ms.

PiperOrigin-RevId: 673025781
This commit is contained in:
sheenachhabra 2024-09-10 10:54:16 -07:00 committed by Copybara-Service
parent 35dc10aac8
commit 4be5b74366

View File

@ -15,7 +15,8 @@
*/
package androidx.media3.muxer;
import androidx.media3.common.C;
import static androidx.media3.common.util.Assertions.checkState;
import androidx.media3.common.MimeTypes;
import com.google.common.collect.ImmutableList;
import java.nio.ByteBuffer;
@ -40,64 +41,34 @@ import java.nio.ByteBuffer;
return ImmutableList.of();
}
int nalStartIndex = C.INDEX_UNSET;
int inputLimit = input.limit();
boolean readingNalUnit = false;
// The algorithm always searches for 0x000001 start code but it will work for 0x00000001 start
// code as well because the first 0 will be considered as a leading 0 and will be skipped.
// The input must start with a NAL unit.
for (int i = 0; i < inputLimit; i++) {
if (isThreeByteNalStartCode(input, i)) {
nalStartIndex = i + THREE_BYTE_NAL_START_CODE_SIZE;
readingNalUnit = true;
break;
} else if (input.get(i) == 0) {
// Skip the leading zeroes.
} else {
throw new IllegalStateException("Sample does not start with a NAL unit");
}
}
int nalStartCodeIndex = skipLeadingZerosAndFindNalStartCodeIndex(input, /* currentIndex= */ 0);
int nalStartIndex = nalStartCodeIndex + THREE_BYTE_NAL_START_CODE_SIZE;
boolean readingNalUnit = true;
ImmutableList.Builder<ByteBuffer> nalUnits = new ImmutableList.Builder<>();
// Look for start code 0x000001. The logic will work for 0x00000001 start code as well because a
// NAL unit gets ended even when 0x000000 (which is a prefix of 0x00000001 start code) is found.
for (int i = nalStartIndex; i < inputLimit; ) {
int i = nalStartIndex;
while (i < input.limit()) {
if (readingNalUnit) {
// Found next start code 0x000001.
if (isThreeByteNalStartCode(input, i)) {
nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex));
i = i + THREE_BYTE_NAL_START_CODE_SIZE;
nalStartIndex = i;
continue;
} else if (isThreeBytesZeroSequence(input, i)) {
// Found code 0x000000; The previous NAL unit should be ended.
nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex));
// Stop reading NAL unit until next start code is found.
readingNalUnit = false;
i++;
} else {
// Continue reading NAL unit.
i++;
}
int nalEndIndex = findNalEndIndex(input, i);
nalUnits.add(getBytes(input, nalStartIndex, nalEndIndex - nalStartIndex));
i = nalEndIndex;
readingNalUnit = false;
} else {
// Found new start code 0x000001.
if (isThreeByteNalStartCode(input, i)) {
i = i + THREE_BYTE_NAL_START_CODE_SIZE;
nalStartIndex = i;
int nextNalStartCodeIndex = skipLeadingZerosAndFindNalStartCodeIndex(input, i);
if (nextNalStartCodeIndex != input.limit()) {
nalStartIndex = nextNalStartCodeIndex + THREE_BYTE_NAL_START_CODE_SIZE;
i = nalStartIndex;
readingNalUnit = true;
} else if (input.get(i) == 0x00) {
// Skip trailing zeroes.
i++;
} else {
// Found garbage data.
throw new IllegalStateException("Invalid NAL units");
break;
}
}
// Add the last NAL unit.
if (i == inputLimit && readingNalUnit) {
nalUnits.add(getBytes(input, nalStartIndex, i - nalStartIndex));
}
}
input.rewind();
return nalUnits.build();
}
@ -137,18 +108,107 @@ import java.nio.ByteBuffer;
|| sampleMimeType.equals(MimeTypes.VIDEO_H265);
}
private static boolean isThreeByteNalStartCode(ByteBuffer input, int currentIndex) {
return (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE
&& input.get(currentIndex) == 0x00
&& input.get(currentIndex + 1) == 0x00
&& input.get(currentIndex + 2) == 0x01);
/**
* Returns the end position (exclusive) of the current NAL unit within the input.
*
* <p>A NAL unit is terminated by one of the following sequences:
*
* <ul>
* <li>0x000000
* <li>0x000001
* <li>The end of the input data.
* </ul>
*
* @param input The {@link ByteBuffer} containing NAL units.
* @param currentIndex The starting position for the search.
* @return The NAL unit end index (exclusive).
*/
private static int findNalEndIndex(ByteBuffer input, int currentIndex) {
while (currentIndex <= input.limit() - 4) {
int fourBytes = input.getInt(currentIndex);
// Check if the first 3 bytes are 0x000000 or 0x000001.
if ((fourBytes & 0xFFFFFF00) == 0 || (fourBytes & 0xFFFFFF00) == 0x00000100) {
return currentIndex;
}
// Check if the last 3 bytes are 0x000000 or 0x000001.
if ((fourBytes & 0x00FFFFFF) == 0 || (fourBytes & 0x00FFFFFF) == 0x00000001) {
return currentIndex + 1;
}
// Check if the last 2 bytes are prefix of 0x000000 or 0x000001.
if ((fourBytes & 0x0000FFFF) == 0) {
currentIndex = currentIndex + 2;
} else if ((fourBytes & 0x000000FF)
== 0) { // Check if the last byte is prefix of 0x000000 or 0x000001.
currentIndex = currentIndex + 3;
} else {
currentIndex = currentIndex + 4;
}
}
// Handle remaining bytes if any (less than 4).
// Last 3 bytes could be 0x000000 or 0x000001.
if (currentIndex == input.limit() - THREE_BYTE_NAL_START_CODE_SIZE) {
short firstTwoBytes = input.getShort(currentIndex);
byte lastByte = input.get(currentIndex + 2);
if (firstTwoBytes == 0 && (lastByte == 0 || lastByte == 1)) {
return currentIndex;
}
}
return input.limit();
}
private static boolean isThreeBytesZeroSequence(ByteBuffer input, int currentIndex) {
return (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE
&& input.get(currentIndex) == 0x00
&& input.get(currentIndex + 1) == 0x00
&& input.get(currentIndex + 2) == 0x00);
/**
* Skips leading zeros and locates the start of the next NAL unit (0x000001).
*
* @param input The {@link ByteBuffer} containing NAL units.
* @param currentIndex The starting position for the search.
* @return The index of the NAL start code, or the end of the input if NAL start code is not
* found.
*/
private static int skipLeadingZerosAndFindNalStartCodeIndex(ByteBuffer input, int currentIndex) {
while (currentIndex <= input.limit() - 4) {
int fourBytes = input.getInt(currentIndex);
// Check if the first 3 bytes is 0x000001.
if ((fourBytes & 0xFFFFFF00) == 0x00000100) {
return currentIndex;
}
// Otherwise the first 3 bytes must be 0.
checkState((fourBytes & 0xFFFFFF00) == 0, "Invalid Nal units");
// Check if the last byte is 1. It then makes last three bytes 0x000001.
if ((fourBytes & 0x000000FF) == 1) {
return currentIndex + 1;
}
// Otherwise the last byte must be 0;
checkState((fourBytes & 0x000000FF) == 0, "Invalid Nal units");
// Last three zeroes can be a prefix of the NAL start code 0x000001.
currentIndex = currentIndex + 1;
}
// Handle remaining bytes if any (less than 4).
// Last 3 bytes could be 0x000001.
if (currentIndex <= input.limit() - THREE_BYTE_NAL_START_CODE_SIZE) {
short firstTwoBytes = input.getShort(currentIndex);
checkState(firstTwoBytes == 0, "Invalid NAL units");
byte lastByte = input.get(currentIndex + 2);
if (lastByte == 1) {
return currentIndex;
}
checkState(lastByte == 0, "Invalid NAL units");
} else {
// Remaining bytes must be 0.
while (currentIndex < input.limit()) {
checkState(input.get(currentIndex) == 0, "Invalid NAL units");
currentIndex++;
}
}
return input.limit();
}
private static ByteBuffer getBytes(ByteBuffer buf, int offset, int length) {