diff --git a/libraries/transformer/src/main/java/androidx/media3/transformer/ChannelMixingMatrix.java b/libraries/transformer/src/main/java/androidx/media3/transformer/ChannelMixingMatrix.java index 9cc87ada11..0b2d91c764 100644 --- a/libraries/transformer/src/main/java/androidx/media3/transformer/ChannelMixingMatrix.java +++ b/libraries/transformer/src/main/java/androidx/media3/transformer/ChannelMixingMatrix.java @@ -49,6 +49,7 @@ import static androidx.media3.common.util.Assertions.checkArgument; private final float[] coefficients; private final boolean isZero; private final boolean isDiagonal; + private final boolean isIdentity; /** * Creates a standard channel mixing matrix that converts from {@code inputChannelCount} channels @@ -71,31 +72,6 @@ import static androidx.media3.common.util.Assertions.checkArgument; createMixingCoefficients(inputChannelCount, outputChannelCount)); } - private static float[] createMixingCoefficients(int inputChannelCount, int outputChannelCount) { - if (inputChannelCount == outputChannelCount) { - int channelCount = inputChannelCount; - float[] coefficients = new float[channelCount * channelCount]; - for (int c = 0; c < channelCount; c++) { - coefficients[channelCount * c + c] = 1f; - } - return coefficients; - } - if (inputChannelCount == 1 && outputChannelCount == 2) { - // Mono -> stereo. - return new float[] {1f, 1f}; - } - if (inputChannelCount == 2 && outputChannelCount == 1) { - // Stereo -> mono. - return new float[] {0.5f, 0.5f}; - } - throw new UnsupportedOperationException( - "Default channel mixing coefficients for " - + inputChannelCount - + "->" - + outputChannelCount - + " are not yet implemented."); - } - /** * Creates a matrix with the given coefficients in row-major order. * @@ -114,20 +90,28 @@ import static androidx.media3.common.util.Assertions.checkArgument; this.coefficients = checkCoefficientsValid(coefficients); // Calculate matrix properties. - boolean hasNonZero = false; - boolean hasNonZeroOutOfDiagonal = false; - for (int i = 0; i < inputChannelCount; i++) { - for (int o = 0; o < outputChannelCount; o++) { - if (getMixingCoefficient(i, o) != 0f) { - hasNonZero = true; - if (i != o) { - hasNonZeroOutOfDiagonal = true; + boolean allDiagonalCoefficientsAreOne = true; + boolean allCoefficientsAreZero = true; + boolean allNonDiagonalCoefficientsAreZero = true; + for (int row = 0; row < inputChannelCount; row++) { + for (int col = 0; col < outputChannelCount; col++) { + float coefficient = getMixingCoefficient(row, col); + boolean onDiagonal = row == col; + + if (coefficient != 1f && onDiagonal) { + allDiagonalCoefficientsAreOne = false; + } + if (coefficient != 0f) { + allCoefficientsAreZero = false; + if (!onDiagonal) { + allNonDiagonalCoefficientsAreZero = false; } } } } - isZero = !hasNonZero; - isDiagonal = isSquare() && !hasNonZeroOutOfDiagonal; + isZero = allCoefficientsAreZero; + isDiagonal = isSquare() && allNonDiagonalCoefficientsAreZero; + isIdentity = isDiagonal && allDiagonalCoefficientsAreOne; } public int getInputChannelCount() { @@ -158,6 +142,11 @@ import static androidx.media3.common.util.Assertions.checkArgument; return isDiagonal; } + /** Returns whether this is an identity matrix. */ + public boolean isIdentity() { + return isIdentity; + } + /** Returns a new matrix with the given scaling factor applied to all coefficients. */ public ChannelMixingMatrix scaleBy(float scale) { float[] scaledCoefficients = new float[coefficients.length]; @@ -167,6 +156,34 @@ import static androidx.media3.common.util.Assertions.checkArgument; return new ChannelMixingMatrix(inputChannelCount, outputChannelCount, scaledCoefficients); } + private static float[] createMixingCoefficients(int inputChannelCount, int outputChannelCount) { + if (inputChannelCount == outputChannelCount) { + return initializeIdentityMatrix(outputChannelCount); + } + if (inputChannelCount == 1 && outputChannelCount == 2) { + // Mono -> stereo. + return new float[] {1f, 1f}; + } + if (inputChannelCount == 2 && outputChannelCount == 1) { + // Stereo -> mono. + return new float[] {0.5f, 0.5f}; + } + throw new UnsupportedOperationException( + "Default channel mixing coefficients for " + + inputChannelCount + + "->" + + outputChannelCount + + " are not yet implemented."); + } + + private static float[] initializeIdentityMatrix(int channelCount) { + float[] coefficients = new float[channelCount * channelCount]; + for (int c = 0; c < channelCount; c++) { + coefficients[channelCount * c + c] = 1f; + } + return coefficients; + } + private static float[] checkCoefficientsValid(float[] coefficients) { for (int i = 0; i < coefficients.length; i++) { if (coefficients[i] < 0f) { diff --git a/libraries/transformer/src/test/java/androidx/media3/transformer/ChannelMixingMatrixTest.java b/libraries/transformer/src/test/java/androidx/media3/transformer/ChannelMixingMatrixTest.java index 162e0a4f01..bade484d6f 100644 --- a/libraries/transformer/src/test/java/androidx/media3/transformer/ChannelMixingMatrixTest.java +++ b/libraries/transformer/src/test/java/androidx/media3/transformer/ChannelMixingMatrixTest.java @@ -25,6 +25,51 @@ import org.junit.runner.RunWith; @RunWith(AndroidJUnit4.class) public class ChannelMixingMatrixTest { + @Test + public void onesOnDiagonal_1To1_hasCorrectProperties() { + int inputCount = 1; + int outputCount = 1; + float[] coefficients = new float[] {1f}; + ChannelMixingMatrix matrix = new ChannelMixingMatrix(inputCount, outputCount, coefficients); + assertThat(matrix.isZero()).isFalse(); + assertThat(matrix.isSquare()).isTrue(); + assertThat(matrix.isDiagonal()).isTrue(); + assertThat(matrix.isIdentity()).isTrue(); + } + + @Test + public void onesOnDiagonal_2To3_hasCorrectProperties() { + int inputCount = 2; + int outputCount = 3; + float[] coefficients = + new float[] { + 1f, 0f, 0f, + 0f, 1f, 0f, + }; + ChannelMixingMatrix matrix = new ChannelMixingMatrix(inputCount, outputCount, coefficients); + assertThat(matrix.isZero()).isFalse(); + assertThat(matrix.isSquare()).isFalse(); + assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); + } + + @Test + public void onesOnDiagonal_3To3_hasCorrectProperties() { + int inputCount = 3; + int outputCount = 3; + float[] coefficients = + new float[] { + 1f, 0f, 0f, + 0f, 1f, 0f, + 0f, 0f, 1f + }; + ChannelMixingMatrix matrix = new ChannelMixingMatrix(inputCount, outputCount, coefficients); + assertThat(matrix.isZero()).isFalse(); + assertThat(matrix.isSquare()).isTrue(); + assertThat(matrix.isDiagonal()).isTrue(); + assertThat(matrix.isIdentity()).isTrue(); + } + @Test public void allZeroValues_3To2_hasCorrectProperties() { int inputCount = 3; @@ -40,6 +85,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isTrue(); assertThat(matrix.isSquare()).isFalse(); assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -57,6 +103,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isTrue(); assertThat(matrix.isSquare()).isTrue(); assertThat(matrix.isDiagonal()).isTrue(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -74,6 +121,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isTrue(); assertThat(matrix.isSquare()).isFalse(); assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -91,6 +139,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isFalse(); assertThat(matrix.isSquare()).isFalse(); assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -107,6 +156,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isFalse(); assertThat(matrix.isSquare()).isTrue(); assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -125,6 +175,7 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isFalse(); assertThat(matrix.isSquare()).isTrue(); assertThat(matrix.isDiagonal()).isTrue(); + assertThat(matrix.isIdentity()).isFalse(); } @Test @@ -141,5 +192,6 @@ public class ChannelMixingMatrixTest { assertThat(matrix.isZero()).isFalse(); assertThat(matrix.isSquare()).isFalse(); assertThat(matrix.isDiagonal()).isFalse(); + assertThat(matrix.isIdentity()).isFalse(); } }