Add MssimCalculatorTest to verify SSIM calculations.

As part of this change, MssimCalculator is moved from androidTest/ to main/.

PiperOrigin-RevId: 473771344
(cherry picked from commit 8ce42f0670504a89c5c3e546f8be3d849be36195)
This commit is contained in:
samrobinson 2022-09-12 16:45:36 +00:00 committed by microkatz
parent 46ff38ebe2
commit eeba63ab3e
3 changed files with 287 additions and 180 deletions

View File

@ -21,7 +21,6 @@ import static androidx.media3.common.util.Assertions.checkState;
import static androidx.media3.common.util.Assertions.checkStateNotNull;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static java.lang.Math.pow;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
@ -327,183 +326,4 @@ public final class SsimHelper {
imageReader.close();
}
}
/**
 * Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik,
 * Sheikh, and Simoncelli.
 *
 * <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
 *
 * <p>Windows are non-overlapping tiles of {@code WINDOW_SIZE} pixels per side. Tiles that would
 * extend past the right or bottom edge are shrunk to the pixels that remain.
 *
 * @see <a href="https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf">The SSIM paper</a>.
 */
private static final class MssimCalculator {
  // Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
  // range of pixel values is 0 to 255 (8 bit unsigned range).
  private static final int PIXEL_MAX_VALUE = 255;

  // K1 and K2, as defined in the SSIM paper.
  private static final double K1 = 0.01;
  private static final double K2 = 0.03;

  // C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
  // (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
  // `getWindowSsim` for how these values impact each other in the calculation.
  private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
  private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);

  // Side length, in pixels, of a full (non-edge) window.
  private static final int WINDOW_SIZE = 8;

  /**
   * Calculates the Mean Structural Similarity (MSSIM) between two images.
   *
   * <p>Both buffers are read row by row with a stride of {@code width}, and each byte is
   * interpreted as an unsigned 8 bit value. Each buffer must hold at least {@code width * height}
   * entries.
   *
   * @param referenceBuffer The luma channel (Y) buffer of the reference image.
   * @param distortedBuffer The luma channel (Y) buffer of the distorted image.
   * @param width The image width in pixels.
   * @param height The image height in pixels.
   * @return The MSSIM score between the input images, or {@code 1.0} if there are no pixels to
   *     compare.
   */
  public static double calculate(
      byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
    double totalSsim = 0;
    int windowsCount = 0;

    for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
      int windowHeight = computeWindowSize(currentWindowY, height);
      for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
        windowsCount++;
        int windowWidth = computeWindowSize(currentWindowX, width);
        // Index of the window's top-left pixel within the image buffers.
        int bufferIndexOffset =
            get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
        double referenceMean =
            getMean(
                referenceBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);
        double distortedMean =
            getMean(
                distortedBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);

        double[] variances =
            getVariancesAndCovariance(
                referenceBuffer,
                distortedBuffer,
                referenceMean,
                distortedMean,
                bufferIndexOffset,
                /* stride= */ width,
                windowWidth,
                windowHeight);
        double referenceVariance = variances[0];
        double distortedVariance = variances[1];
        double referenceDistortedCovariance = variances[2];

        totalSsim +=
            getWindowSsim(
                referenceMean,
                distortedMean,
                referenceVariance,
                distortedVariance,
                referenceDistortedCovariance);
      }
    }

    if (windowsCount == 0) {
      // Neither loop ran (non-positive width or height), so there is nothing that differs.
      return 1.0d;
    }
    return totalSsim / windowsCount;
  }

  /**
   * Returns the window size at the provided start coordinate, uses {@link #WINDOW_SIZE} if there
   * is enough space, otherwise the number of pixels between {@code start} and {@code dimension}.
   */
  private static int computeWindowSize(int start, int dimension) {
    if (start + WINDOW_SIZE <= dimension) {
      return WINDOW_SIZE;
    }
    return dimension - start;
  }

  /** Returns the SSIM of a window. */
  private static double getWindowSsim(
      double referenceMean,
      double distortedMean,
      double referenceVariance,
      double distortedVariance,
      double referenceDistortedCovariance) {
    // Uses equation 13 on page 6 from the linked paper.
    double numerator =
        (((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
    double denominator =
        ((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
            * (referenceVariance + distortedVariance + C2);
    return numerator / denominator;
  }

  /** Returns the mean of the pixels in the window. */
  private static double getMean(
      byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
    double total = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        // Mask to read the byte as an unsigned value in the range 0 to 255.
        total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF;
      }
    }
    return total / (windowWidth * windowHeight);
  }

  /**
   * Calculates the variances and covariance of the pixels in the window for both buffers.
   *
   * <p>The sums of squared deviations are normalized by {@code pixelCount - 1}. A window with a
   * single pixel has no spread, so all three values are reported as 0 instead of dividing by 0.
   *
   * @return An array holding the reference variance, the distorted variance and the covariance,
   *     in that order.
   */
  private static double[] getVariancesAndCovariance(
      byte[] referenceBuffer,
      byte[] distortedBuffer,
      double referenceMean,
      double distortedMean,
      int bufferIndexOffset,
      int stride,
      int windowWidth,
      int windowHeight) {
    int pixelCount = windowWidth * windowHeight;
    if (pixelCount <= 1) {
      // A 1x1 edge window would give 0.0 / 0, which is NaN and would poison the averaged score.
      return new double[] {0, 0, 0};
    }

    double referenceVariance = 0;
    double distortedVariance = 0;
    double referenceDistortedCovariance = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        int index = get1dIndex(x, y, stride, bufferIndexOffset);
        double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean;
        double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean;
        referenceVariance += referencePixelDeviation * referencePixelDeviation;
        distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
        referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
      }
    }

    int normalizationFactor = pixelCount - 1;
    return new double[] {
      referenceVariance / normalizationFactor,
      distortedVariance / normalizationFactor,
      referenceDistortedCovariance / normalizationFactor
    };
  }

  /**
   * Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
   *
   * @param x The width component of coordinate.
   * @param y The height component of coordinate.
   * @param stride The width of the 2D space.
   * @param offset An offset to apply.
   * @return The 1D index.
   */
  private static int get1dIndex(int x, int y, int stride, int offset) {
    return x + (y * stride) + offset;
  }
}
}

View File

@ -0,0 +1,191 @@
/*
* Copyright 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.media3.transformer;
import static java.lang.Math.pow;
/**
 * Image comparison tool that calculates the Mean Structural Similarity (MSSIM) of two images,
 * developed by Wang, Bovik, Sheikh, and Simoncelli.
 *
 * <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
 *
 * <p>Windows are non-overlapping tiles of {@code WINDOW_SIZE} pixels per side. Tiles that would
 * extend past the right or bottom edge are shrunk to the pixels that remain.
 *
 * @see <a href="https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf">The SSIM paper</a>.
 */
/* package */ final class MssimCalculator {
  // Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
  // range of pixel values is 0 to 255 (8 bit unsigned range).
  private static final int PIXEL_MAX_VALUE = 255;

  // K1 and K2, as defined in the SSIM paper.
  private static final double K1 = 0.01;
  private static final double K2 = 0.03;

  // C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
  // (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
  // `getWindowSsim` for how these values impact each other in the calculation.
  private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
  private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);

  // Side length, in pixels, of a full (non-edge) window.
  private static final int WINDOW_SIZE = 8;

  private MssimCalculator() {}

  /**
   * Calculates the Mean Structural Similarity (MSSIM) between two images.
   *
   * <p>Both buffers are read row by row with a stride of {@code width}, and each byte is
   * interpreted as an unsigned 8 bit value. Each buffer must hold at least {@code width * height}
   * entries.
   *
   * @param referenceBuffer The luma channel (Y) buffer of the reference image.
   * @param distortedBuffer The luma channel (Y) buffer of the distorted image.
   * @param width The image width in pixels.
   * @param height The image height in pixels.
   * @return The MSSIM score between the input images, or {@code 1.0} if there are no pixels to
   *     compare.
   */
  public static double calculate(
      byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
    double totalSsim = 0;
    int windowsCount = 0;

    for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
      int windowHeight = computeWindowSize(currentWindowY, height);
      for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
        windowsCount++;
        int windowWidth = computeWindowSize(currentWindowX, width);
        // Index of the window's top-left pixel within the image buffers.
        int bufferIndexOffset =
            get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
        double referenceMean =
            getMean(
                referenceBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);
        double distortedMean =
            getMean(
                distortedBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);

        double[] variances =
            getVariancesAndCovariance(
                referenceBuffer,
                distortedBuffer,
                referenceMean,
                distortedMean,
                bufferIndexOffset,
                /* stride= */ width,
                windowWidth,
                windowHeight);
        double referenceVariance = variances[0];
        double distortedVariance = variances[1];
        double referenceDistortedCovariance = variances[2];

        totalSsim +=
            getWindowSsim(
                referenceMean,
                distortedMean,
                referenceVariance,
                distortedVariance,
                referenceDistortedCovariance);
      }
    }

    if (windowsCount == 0) {
      // Neither loop ran (non-positive width or height), so there is nothing that differs.
      return 1.0d;
    }
    return totalSsim / windowsCount;
  }

  /**
   * Returns the window size at the provided start coordinate, uses {@link #WINDOW_SIZE} if there is
   * enough space, otherwise the number of pixels between {@code start} and {@code dimension}.
   */
  private static int computeWindowSize(int start, int dimension) {
    if (start + WINDOW_SIZE <= dimension) {
      return WINDOW_SIZE;
    }
    return dimension - start;
  }

  /** Returns the SSIM of a window. */
  private static double getWindowSsim(
      double referenceMean,
      double distortedMean,
      double referenceVariance,
      double distortedVariance,
      double referenceDistortedCovariance) {
    // Uses equation 13 on page 6 from the linked paper.
    double numerator =
        (((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
    double denominator =
        ((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
            * (referenceVariance + distortedVariance + C2);
    return numerator / denominator;
  }

  /** Returns the mean of the pixels in the window. */
  private static double getMean(
      byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
    double total = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        // Mask to read the byte as an unsigned value in the range 0 to 255.
        total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF;
      }
    }
    return total / (windowWidth * windowHeight);
  }

  /**
   * Calculates the variances and covariance of the pixels in the window for both buffers.
   *
   * <p>The sums of squared deviations are normalized by {@code pixelCount - 1}. A window with a
   * single pixel has no spread, so all three values are reported as 0 instead of dividing by 0.
   *
   * @return An array holding the reference variance, the distorted variance and the covariance,
   *     in that order.
   */
  private static double[] getVariancesAndCovariance(
      byte[] referenceBuffer,
      byte[] distortedBuffer,
      double referenceMean,
      double distortedMean,
      int bufferIndexOffset,
      int stride,
      int windowWidth,
      int windowHeight) {
    int pixelCount = windowWidth * windowHeight;
    if (pixelCount <= 1) {
      // A 1x1 edge window would give 0.0 / 0, which is NaN and would poison the averaged score.
      return new double[] {0, 0, 0};
    }

    double referenceVariance = 0;
    double distortedVariance = 0;
    double referenceDistortedCovariance = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        int index = get1dIndex(x, y, stride, bufferIndexOffset);
        double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean;
        double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean;
        referenceVariance += referencePixelDeviation * referencePixelDeviation;
        distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
        referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
      }
    }

    int normalizationFactor = pixelCount - 1;
    return new double[] {
      referenceVariance / normalizationFactor,
      distortedVariance / normalizationFactor,
      referenceDistortedCovariance / normalizationFactor
    };
  }

  /**
   * Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
   *
   * @param x The width component of coordinate.
   * @param y The height component of coordinate.
   * @param stride The width of the 2D space.
   * @param offset An offset to apply.
   * @return The 1D index.
   */
  private static int get1dIndex(int x, int y, int stride, int offset) {
    return x + (y * stride) + offset;
  }
}

View File

@ -0,0 +1,96 @@
/*
* Copyright 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.media3.transformer;
import static androidx.test.core.app.ApplicationProvider.getApplicationContext;
import static com.google.common.truth.Truth.assertThat;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Color;
import androidx.annotation.ColorInt;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import java.io.IOException;
import java.io.InputStream;
import org.junit.Test;
import org.junit.runner.RunWith;
/** Unit tests for {@link MssimCalculator}, run against bitmap assets bundled with the app. */
@RunWith(AndroidJUnit4.class)
public class MssimCalculatorTest {

  @Test
  public void calculateSsim_sameImage() throws Exception {
    Bitmap image = readBitmap("media/bitmap/sample_mp4_first_frame/original.png");
    byte[] luma = bitmapToLuminosityArray(image);

    double ssim = MssimCalculator.calculate(luma, luma, image.getWidth(), image.getHeight());

    // Comparing an image with itself must give the maximum score of 1.
    assertThat(ssim).isEqualTo(1);
  }

  @Test
  public void calculateSsim_increasedBrightness() throws Exception {
    Bitmap referenceImage = readBitmap("media/bitmap/sample_mp4_first_frame/original.png");
    Bitmap brightenedImage =
        readBitmap("media/bitmap/sample_mp4_first_frame/increase_brightness.png");

    double ssim =
        MssimCalculator.calculate(
            bitmapToLuminosityArray(referenceImage),
            bitmapToLuminosityArray(brightenedImage),
            referenceImage.getWidth(),
            referenceImage.getHeight());
    // Truncate to whole percent before comparing.
    int ssimPercent = (int) (ssim * 100);

    // SSIM as calculated by ffmpeg: 0.634326 = 63%
    assertThat(ssimPercent).isEqualTo(63);
  }

  /** Decodes the asset at {@code assetString} into a {@link Bitmap}. */
  private static Bitmap readBitmap(String assetString) throws IOException {
    try (InputStream assetStream = getApplicationContext().getAssets().open(assetString)) {
      return BitmapFactory.decodeStream(assetStream);
    }
  }

  /** Converts every pixel of {@code bitmap}, in row order, into a single luminosity byte. */
  private static byte[] bitmapToLuminosityArray(Bitmap bitmap) {
    int bitmapWidth = bitmap.getWidth();
    int bitmapHeight = bitmap.getHeight();
    int pixelCount = bitmapWidth * bitmapHeight;

    @ColorInt int[] argbPixels = new int[pixelCount];
    bitmap.getPixels(
        argbPixels,
        /* offset= */ 0,
        /* stride= */ bitmapWidth,
        /* x= */ 0,
        /* y= */ 0,
        bitmapWidth,
        bitmapHeight);

    byte[] result = new byte[pixelCount];
    int position = 0;
    for (@ColorInt int argb : argbPixels) {
      int luminosity = getLuminosity(argb);
      result[position++] = (byte) (luminosity & 0xFF);
    }
    return result;
  }

  /**
   * Gets the intensity of a given RGB {@link ColorInt pixel} using the luminosity formula:
   *
   * <pre>l = 0.2126R + 0.7152G + 0.0722B</pre>
   *
   * <p>The weighted sum is truncated to an integer.
   */
  private static int getLuminosity(@ColorInt int pixel) {
    // Each product is computed in float and only then widened to double for the summation.
    double redPart = 0.2126f * Color.red(pixel);
    double greenPart = 0.7152f * Color.green(pixel);
    double bluePart = 0.0722f * Color.blue(pixel);
    return (int) (redPart + greenPart + bluePart);
  }
}