Fix truncation error accumulation on Sonic's time stretching algorithm

This CL also fixes EOS handling to account for not-yet-copied samples in
`remainingInputToCopyFrameCount`, which would throw off the final output
sample count calculation.

For testing, we allow a tolerance of 0.000017% drift between the expected
and actual number of output samples. The value was obtained by running
100 iterations of `timeStretching_returnsExpectedNumberOfSamples()` and
calculating the average delta percentage between the expected and actual
number of output samples. Roughly, this means a tolerance of 44 samples
on a 90 min mono stream @48KHz.

PiperOrigin-RevId: 683133461
This commit is contained in:
ivanbuper 2024-10-07 04:57:13 -07:00 committed by Copybara-Service
parent f7af58951d
commit af922fbcb0
6 changed files with 241 additions and 14 deletions

View File

@ -63,6 +63,8 @@
* DataSource:
* Audio:
* Fix pop sounds that may occur during seeks.
* Fix truncation error accumulation for Sonic's
time-stretching/pitch-shifting algorithm.
* Video:
* Add workaround for a device issue on Galaxy Tab S7 FE that causes 60fps
secure H264 streams to be marked as unsupported

View File

@ -52,11 +52,23 @@ import java.util.Arrays;
private int pitchFrameCount;
private int oldRatePosition;
private int newRatePosition;
/**
* Number of frames pending to be copied from {@link #inputBuffer} directly to {@link
* #outputBuffer}.
*
* <p>This field is only relevant to time-stretching or pitch-shifting in {@link
* #changeSpeed(double)}, particularly when more frames need to be copied to the {@link
* #outputBuffer} than are available in {@link #inputBuffer} and Sonic must wait until the next
* buffer (or EOS) is queued.
*/
private int remainingInputToCopyFrameCount;
private int prevPeriod;
private int prevMinDiff;
private int minDiff;
private int maxDiff;
private double accumulatedSpeedAdjustmentError;
/**
* Creates a new Sonic audio stream processor.
@ -130,10 +142,26 @@ import java.util.Arrays;
*/
public void queueEndOfStream() {
int remainingFrameCount = inputFrameCount;
float s = speed / pitch;
float r = rate * pitch;
double s = speed / pitch;
double r = rate * pitch;
// If there are frames to be copied directly onto the output buffer, we should not count those
// as "input frames" because Sonic is not applying any processing on them.
int adjustedRemainingFrames = remainingFrameCount - remainingInputToCopyFrameCount;
// We add directly to the output the number of frames in remainingInputToCopyFrameCount.
// Otherwise, expectedOutputFrames will be off and will make Sonic output an incorrect number of
// frames.
int expectedOutputFrames =
outputFrameCount + (int) ((remainingFrameCount / s + pitchFrameCount) / r + 0.5f);
outputFrameCount
+ (int)
((adjustedRemainingFrames / s
+ remainingInputToCopyFrameCount
+ accumulatedSpeedAdjustmentError
+ pitchFrameCount)
/ r
+ 0.5);
accumulatedSpeedAdjustmentError = 0;
// Add enough silence to flush both input and pitch buffers.
inputBuffer =
@ -166,6 +194,7 @@ import java.util.Arrays;
prevMinDiff = 0;
minDiff = 0;
maxDiff = 0;
accumulatedSpeedAdjustmentError = 0;
}
/** Returns the size of output that can be read with {@link #getOutput(ShortBuffer)}, in bytes. */
@ -408,14 +437,19 @@ import java.util.Arrays;
removePitchFrames(pitchFrameCount - 1);
}
private int skipPitchPeriod(short[] samples, int position, float speed, int period) {
private int skipPitchPeriod(short[] samples, int position, double speed, int period) {
// Skip over a pitch period, and copy period/speed samples to the output.
int newFrameCount;
if (speed >= 2.0f) {
newFrameCount = (int) (period / (speed - 1.0f));
double expectedFrameCount = period / (speed - 1.0) + accumulatedSpeedAdjustmentError;
newFrameCount = (int) Math.round(expectedFrameCount);
accumulatedSpeedAdjustmentError = expectedFrameCount - newFrameCount;
} else {
newFrameCount = period;
remainingInputToCopyFrameCount = (int) (period * (2.0f - speed) / (speed - 1.0f));
double expectedInputToCopy =
period * (2.0f - speed) / (speed - 1.0f) + accumulatedSpeedAdjustmentError;
remainingInputToCopyFrameCount = (int) Math.round(expectedInputToCopy);
accumulatedSpeedAdjustmentError = expectedInputToCopy - remainingInputToCopyFrameCount;
}
outputBuffer = ensureSpaceForAdditionalFrames(outputBuffer, outputFrameCount, newFrameCount);
overlapAdd(
@ -431,14 +465,19 @@ import java.util.Arrays;
return newFrameCount;
}
private int insertPitchPeriod(short[] samples, int position, float speed, int period) {
private int insertPitchPeriod(short[] samples, int position, double speed, int period) {
// Insert a pitch period, and determine how much input to copy directly.
int newFrameCount;
if (speed < 0.5f) {
newFrameCount = (int) (period * speed / (1.0f - speed));
double expectedFrameCount = period * speed / (1.0f - speed) + accumulatedSpeedAdjustmentError;
newFrameCount = (int) Math.round(expectedFrameCount);
accumulatedSpeedAdjustmentError = expectedFrameCount - newFrameCount;
} else {
newFrameCount = period;
remainingInputToCopyFrameCount = (int) (period * (2.0f * speed - 1.0f) / (1.0f - speed));
double expectedInputToCopy =
period * (2.0f * speed - 1.0f) / (1.0f - speed) + accumulatedSpeedAdjustmentError;
remainingInputToCopyFrameCount = (int) Math.round(expectedInputToCopy);
accumulatedSpeedAdjustmentError = expectedInputToCopy - remainingInputToCopyFrameCount;
}
outputBuffer =
ensureSpaceForAdditionalFrames(outputBuffer, outputFrameCount, period + newFrameCount);
@ -461,7 +500,7 @@ import java.util.Arrays;
return newFrameCount;
}
private void changeSpeed(float speed) {
private void changeSpeed(double speed) {
if (inputFrameCount < maxRequiredFrameCount) {
return;
}
@ -485,7 +524,7 @@ import java.util.Arrays;
private void processStreamInput() {
// Resample as many pitch periods as we have buffered on the input.
int originalOutputFrameCount = outputFrameCount;
float s = speed / pitch;
double s = speed / pitch;
float r = rate * pitch;
if (s > 1.00001 || s < 0.99999) {
changeSpeed(s);

View File

@ -16,6 +16,7 @@
package androidx.media3.common.audio;
import static com.google.common.truth.Truth.assertThat;
import static java.lang.Math.max;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@ -45,9 +46,30 @@ public final class RandomParameterizedSonicTest {
private static final int PARAM_COUNT = 5;
private static final int SPEED_DECIMAL_PRECISION = 2;
/**
* Allowed error tolerance ratio for number of output samples for Sonic's time stretching
* algorithm.
*
* <p>The actual tolerance is calculated as {@code expectedOutputSampleCount *
* TIME_STRETCHING_SAMPLE_DRIFT_TOLERANCE}, rounded to the nearest integer value. However, we
* always allow a minimum tolerance of ±1 sample.
*
* <p>This tolerance is roughly equal to an error of 900us/~44 samples/0.000017% for a 90 min mono
* stream @48KHz. To obtain the value, we ran 100 iterations of {@link
* #timeStretching_returnsExpectedNumberOfSamples()} (by setting {@link #PARAM_COUNT} to 10) and
* we calculated the average delta percentage between expected number of samples and actual number
* of samples (b/366169590).
*/
private static final BigDecimal TIME_STRETCHING_SAMPLE_DRIFT_TOLERANCE =
new BigDecimal("0.00000017");
private static final ImmutableList<Range<Float>> SPEED_RANGES =
ImmutableList.of(
Range.closedOpen(0f, 1f), Range.closedOpen(1f, 2f), Range.closedOpen(2f, 20f));
Range.closedOpen(0f, 0.5f),
Range.closedOpen(0.5f, 1f),
Range.closedOpen(1f, 2f),
Range.closedOpen(2f, 20f));
private static final Random random = new Random(/* seed */ 0);
@ -165,6 +187,55 @@ public final class RandomParameterizedSonicTest {
.of(expectedSize.longValueExact() - accumulatedError.longValueExact());
}
@Test
public void timeStretching_returnsExpectedNumberOfSamples() {
  // Feeds streamLength random mono samples through Sonic in blocks of up to BLOCK_SIZE samples
  // and checks that the number of samples read back is streamLength / speed, within the allowed
  // drift tolerance.
  byte[] buf = new byte[BLOCK_SIZE * BYTES_PER_SAMPLE];
  ShortBuffer outBuffer = ShortBuffer.allocate(BLOCK_SIZE);
  Sonic sonic =
      new Sonic(
          /* inputSampleRateHz= */ SAMPLE_RATE,
          /* channelCount= */ 1,
          speed,
          /* pitch= */ 1,
          /* outputSampleRateHz= */ SAMPLE_RATE);
  long readSampleCount = 0;

  for (long samplesLeft = streamLength; samplesLeft > 0; samplesLeft -= BLOCK_SIZE) {
    random.nextBytes(buf);
    // The comparison is strict so that the last block always takes the end-of-stream branch,
    // even when streamLength is an exact multiple of BLOCK_SIZE. With a non-strict comparison
    // the end of stream would never be queued in that case and the frames still pending inside
    // Sonic would not be counted.
    if (samplesLeft > BLOCK_SIZE) {
      sonic.queueInput(ByteBuffer.wrap(buf).asShortBuffer());
    } else {
      // Last block: queue only the samples that are left (at most BLOCK_SIZE) and signal EOS.
      sonic.queueInput(
          ByteBuffer.wrap(buf, 0, (int) (samplesLeft * BYTES_PER_SAMPLE)).asShortBuffer());
      sonic.queueEndOfStream();
    }
    // Drain everything Sonic has produced so far; outBuffer is reused between reads.
    while (sonic.getOutputSize() > 0) {
      sonic.getOutput(outBuffer);
      readSampleCount += outBuffer.position();
      outBuffer.clear();
    }
  }
  sonic.flush();

  BigDecimal bigSpeed = new BigDecimal(String.valueOf(speed));
  BigDecimal bigLength = new BigDecimal(String.valueOf(streamLength));
  // BigDecimal#divide(BigDecimal, RoundingMode) keeps the scale of the dividend, so
  // expectedSampleCount has the same scale as bigLength and the result is always an integer.
  BigDecimal expectedSampleCount = bigLength.divide(bigSpeed, RoundingMode.HALF_EVEN);

  // Calculate allowed tolerance and round to nearest integer.
  BigDecimal allowedTolerance =
      TIME_STRETCHING_SAMPLE_DRIFT_TOLERANCE
          .multiply(expectedSampleCount)
          .setScale(/* newScale= */ 0, RoundingMode.HALF_EVEN);
  // Always allow at least 1 sample of tolerance.
  long tolerance = max(allowedTolerance.longValue(), 1);

  assertThat(readSampleCount).isWithin(tolerance).of(expectedSampleCount.longValueExact());
}
private static float round(float num) {
BigDecimal bigDecimal = new BigDecimal(Float.toString(num));
return bigDecimal.setScale(SPEED_DECIMAL_PRECISION, RoundingMode.HALF_EVEN).floatValue();

View File

@ -0,0 +1,62 @@
format audio:
averageBitrate = 131072
sampleMimeType = audio/mp4a-latm
channelCount = 1
sampleRate = 44100
pcmEncoding = 2
sample:
trackType = audio
dataHashCode = -858457440
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -317223982
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -510794633
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -392394518
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -1161865299
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = 251977808
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -2046238978
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -1083051456
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = 1068783564
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -825415045
size = 4096
isKeyFrame = true
sample:
trackType = audio
dataHashCode = -1525522823
size = 3140
isKeyFrame = true
released = true

View File

@ -2122,8 +2122,9 @@ public class TransformerEndToEndTest {
sonic.setPitch(resamplingRate);
Effects effects =
new Effects(
ImmutableList.of(sonic, createByteCountingAudioProcessor(readBytes)),
ImmutableList.of());
/* audioProcessors= */ ImmutableList.of(
sonic, createByteCountingAudioProcessor(readBytes)),
/* videoEffects= */ ImmutableList.of());
EditedMediaItem editedMediaItem =
new EditedMediaItem.Builder(MediaItem.fromUri(WAV_ASSET.uri)).setEffects(effects).build();
@ -2137,6 +2138,28 @@ public class TransformerEndToEndTest {
assertThat(readBytes.get() / 2).isWithin(1).of(29400);
}
@Test
public void adjustAudioSpeed_to2pt5Speed_hasExpectedOutputSampleCount() throws Exception {
  // Accumulates the number of bytes seen by the byte-counting processor placed after Sonic.
  AtomicInteger outputByteCount = new AtomicInteger();
  Transformer transformer = new Transformer.Builder(context).build();
  SonicAudioProcessor speedProcessor = new SonicAudioProcessor();
  speedProcessor.setSpeed(2.5f);
  EditedMediaItem editedMediaItem =
      new EditedMediaItem.Builder(MediaItem.fromUri(WAV_ASSET.uri))
          .setEffects(
              new Effects(
                  /* audioProcessors= */ ImmutableList.of(
                      speedProcessor, createByteCountingAudioProcessor(outputByteCount)),
                  /* videoEffects= */ ImmutableList.of()))
          .build();

  new TransformerAndroidTestRunner.Builder(context, transformer)
      .build()
      .run(testId, editedMediaItem);

  // The test file contains 44100 samples (1 sec @44.1KHz, mono), so at 2.5x speed we expect
  // 44100 / 2.5 = 17640 output samples. The byte count is halved to obtain the sample count.
  int outputSampleCount = outputByteCount.get() / 2;
  assertThat(outputSampleCount).isEqualTo(17640);
}
@Test
public void speedAdjustedMedia_shorterAudioTrack_completesWithCorrectDuration() throws Exception {
assumeFormatsSupported(

View File

@ -596,6 +596,36 @@ public final class MediaItemExportTest {
getDumpFileName(/* originalFileName= */ FILE_AUDIO_RAW, /* modifications...= */ "48000hz"));
}
@Test
public void adjustAudioSpeed_toDoubleSpeed_returnsExpectedNumberOfSamples() throws Exception {
  CapturingMuxer.Factory capturingMuxerFactory =
      new CapturingMuxer.Factory(/* handleAudioAsPcm= */ true);
  SonicAudioProcessor doubleSpeedProcessor = new SonicAudioProcessor();
  doubleSpeedProcessor.setSpeed(2f);
  Transformer transformer =
      createTransformerBuilder(capturingMuxerFactory, /* enableFallback= */ false).build();
  // Accumulates the number of bytes seen by the byte-counting processor placed after Sonic.
  AtomicInteger outputByteCount = new AtomicInteger();
  EditedMediaItem.Builder editedMediaItemBuilder =
      new EditedMediaItem.Builder(MediaItem.fromUri(ASSET_URI_PREFIX + FILE_AUDIO_RAW));
  EditedMediaItem editedMediaItem =
      editedMediaItemBuilder
          .setEffects(
              createAudioEffects(
                  doubleSpeedProcessor, createByteCountingAudioProcessor(outputByteCount)))
          .build();
  String outputPath = outputDir.newFile().getPath();

  transformer.start(editedMediaItem, outputPath);
  TransformerTestRunner.runLooper(transformer);

  // Time stretching 1 second @ 44100Hz into 22050 samples. The byte count is halved to obtain the
  // sample count.
  int outputSampleCount = outputByteCount.get() / 2;
  assertThat(outputSampleCount).isEqualTo(22050);
  DumpFileAsserts.assertOutput(
      context,
      capturingMuxerFactory.getCreatedMuxer(),
      getDumpFileName(
          /* originalFileName= */ FILE_AUDIO_RAW, /* modifications...= */ "doubleSpeed"));
}
@Test
public void start_withRawBigEndianAudioInput_completesSuccessfully() throws Exception {
CapturingMuxer.Factory muxerFactory = new CapturingMuxer.Factory(/* handleAudioAsPcm= */ true);