Clarify variables and improve documentation of SSIM.

PiperOrigin-RevId: 451392021
This commit is contained in:
samrobinson 2022-05-27 14:32:22 +00:00 committed by Marc Baechinger
parent 224761833f
commit 6c4f6ecf46
2 changed files with 117 additions and 69 deletions

View File

@ -63,66 +63,65 @@ public final class SsimHelper {
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3; private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
/** /**
* Returns the mean SSIM score between the reference and the actual video. * Returns the mean SSIM score between the reference and the distorted video.
* *
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos. * <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
* *
* @param context The {@link Context}. * @param context The {@link Context}.
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain * @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
* Context#getAssets() Assets}. * Context#getAssets() Assets}.
* @param actualVideoPath The path to the actual video file. * @param distortedVideoPath The path to the distorted video file.
* @throws IOException When unable to open the provided video paths. * @throws IOException When unable to open the provided video paths.
*/ */
public static double calculate(Context context, String referenceVideoPath, String actualVideoPath) public static double calculate(
Context context, String referenceVideoPath, String distortedVideoPath)
throws IOException, InterruptedException { throws IOException, InterruptedException {
VideoDecodingWrapper referenceDecodingWrapper = VideoDecodingWrapper referenceDecodingWrapper =
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper actualDecodingWrapper = VideoDecodingWrapper distortedDecodingWrapper =
new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
@Nullable byte[] referenceLumaBuffer = null; @Nullable byte[] referenceLumaBuffer = null;
@Nullable byte[] actualLumaBuffer = null; @Nullable byte[] distortedLumaBuffer = null;
double accumulatedSsim = 0.0; double accumulatedSsim = 0.0;
int comparedImagesCount = 0; int comparedImagesCount = 0;
try { try {
while (true) { while (true) {
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded(); @Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
@Nullable Image actualImage = actualDecodingWrapper.runUntilComparisonFrameOrEnded(); @Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded();
if (referenceImage == null) { if (referenceImage == null) {
assertThat(actualImage).isNull(); assertThat(distortedImage).isNull();
break; break;
} }
checkNotNull(actualImage); checkNotNull(distortedImage);
int width = referenceImage.getWidth(); int width = referenceImage.getWidth();
int height = referenceImage.getHeight(); int height = referenceImage.getHeight();
assertThat(actualImage.getWidth()).isEqualTo(width); assertThat(distortedImage.getWidth()).isEqualTo(width);
assertThat(actualImage.getHeight()).isEqualTo(height); assertThat(distortedImage.getHeight()).isEqualTo(height);
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) { if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
referenceLumaBuffer = new byte[width * height]; referenceLumaBuffer = new byte[width * height];
} }
if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) { if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
actualLumaBuffer = new byte[width * height]; distortedLumaBuffer = new byte[width * height];
} }
try { try {
accumulatedSsim += accumulatedSsim +=
SsimCalculator.calculate( MssimCalculator.calculate(
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer), extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
extractLumaChannelBuffer(actualImage, actualLumaBuffer), extractLumaChannelBuffer(distortedImage, distortedLumaBuffer),
/* offset= */ 0,
/* stride= */ width,
width, width,
height); height);
} finally { } finally {
referenceImage.close(); referenceImage.close();
actualImage.close(); distortedImage.close();
} }
comparedImagesCount++; comparedImagesCount++;
} }
} finally { } finally {
referenceDecodingWrapper.close(); referenceDecodingWrapper.close();
actualDecodingWrapper.close(); distortedDecodingWrapper.close();
} }
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0); assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
return accumulatedSsim / comparedImagesCount; return accumulatedSsim / comparedImagesCount;
@ -330,55 +329,87 @@ public final class SsimHelper {
} }
/** /**
* Image comparison using the Structural Similarity Index, developed by Wang, Bovik, Sheikh, and * Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik,
* Simoncelli. * Sheikh, and Simoncelli.
*
* <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
* *
* @see <a href=https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf>The SSIM paper</a>. * @see <a href=https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf>The SSIM paper</a>.
*/ */
private static final class SsimCalculator { private static final class MssimCalculator {
// These values were taken from the SSIM paper. Please see the linked paper for details. // Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
private static final double IMAGE_DYNAMIC_RANGE = 255; // range of pixel values is 0 to 255 (8 bit unsigned range).
private static final int PIXEL_MAX_VALUE = 255;
// K1 and K2, as defined in the SSIM paper.
private static final double K1 = 0.01; private static final double K1 = 0.01;
private static final double K2 = 0.03; private static final double K2 = 0.03;
private static final double C1 = pow(IMAGE_DYNAMIC_RANGE * K1, 2);
private static final double C2 = pow(IMAGE_DYNAMIC_RANGE * K2, 2); // C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
// (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
// `getWindowSsim` for how these values impact each other in the calculation.
private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);
private static final int WINDOW_SIZE = 8; private static final int WINDOW_SIZE = 8;
/** /**
* Calculates the Structural Similarity Index (SSIM) between two images. * Calculates the Mean Structural Similarity (MSSIM) between two images.
* *
* @param reference The luma channel (Y) bitmap of the reference image. * @param referenceBuffer The luma channel (Y) buffer of the reference image.
* @param actual The luma channel (Y) bitmap of the actual image. * @param distortedBuffer The luma channel (Y) buffer of the distorted image.
* @param offset The offset.
* @param stride The stride of the bitmap.
* @param width The image width in pixels. * @param width The image width in pixels.
* @param height The image height in pixels. * @param height The image height in pixels.
* @return The SSIM score between the input images. * @return The MSSIM score between the input images.
*/ */
public static double calculate( public static double calculate(
byte[] reference, byte[] actual, int offset, int stride, int width, int height) { byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
double totalSsim = 0; double totalSsim = 0;
int windowsCount = 0; int windowsCount = 0;
// X refers to the reference image, while Y refers to the actual image.
for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) { for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
int windowHeight = computeWindowSize(currentWindowY, height); int windowHeight = computeWindowSize(currentWindowY, height);
for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) { for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
windowsCount++; windowsCount++;
int windowWidth = computeWindowSize(currentWindowX, width); int windowWidth = computeWindowSize(currentWindowX, width);
int start = getGlobalCoordinate(currentWindowX, currentWindowY, stride, offset); int bufferIndexOffset =
double meanX = getMean(reference, start, stride, windowWidth, windowHeight); get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
double meanY = getMean(actual, start, stride, windowWidth, windowHeight); double referenceMean =
getMean(
referenceBuffer,
bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double distortedMean =
getMean(
distortedBuffer,
bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double[] variances = double[] variances =
getVariancesAndCovariance( getVariancesAndCovariance(
reference, actual, meanX, meanY, start, stride, windowWidth, windowHeight); referenceBuffer,
// varX is the variance of window X, covXY is the covariance between window X and Y. distortedBuffer,
double varX = variances[0]; referenceMean,
double varY = variances[1]; distortedMean,
double covXY = variances[2]; bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double referenceVariance = variances[0];
double distortedVariance = variances[1];
double referenceDistortedCovariance = variances[2];
totalSsim += getWindowSsim(meanX, meanY, varX, varY, covXY); totalSsim +=
getWindowSsim(
referenceMean,
distortedMean,
referenceVariance,
distortedVariance,
referenceDistortedCovariance);
} }
} }
@ -402,59 +433,76 @@ public final class SsimHelper {
/** Returns the SSIM of a window. */ /** Returns the SSIM of a window. */
private static double getWindowSsim( private static double getWindowSsim(
double meanX, double meanY, double varX, double varY, double covXY) { double referenceMean,
double distortedMean,
double referenceVariance,
double distortedVariance,
double referenceDistortedCovariance) {
// Uses equation 13 on page 6 from the linked paper. // Uses equation 13 on page 6 from the linked paper.
double numerator = (((2 * meanX * meanY) + C1) * ((2 * covXY) + C2)); double numerator =
double denominator = ((meanX * meanX) + (meanY * meanY) + C1) * (varX + varY + C2); (((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
double denominator =
((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
* (referenceVariance + distortedVariance + C2);
return numerator / denominator; return numerator / denominator;
} }
/** Returns the means of the pixels in the two provided windows, in order. */ /** Returns the mean of the pixels in the window. */
private static double getMean( private static double getMean(
byte[] pixels, int start, int stride, int windowWidth, int windowHeight) { byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
double total = 0; double total = 0;
for (int y = 0; y < windowHeight; y++) { for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) { for (int x = 0; x < windowWidth; x++) {
total += pixels[getGlobalCoordinate(x, y, stride, start)]; total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
} }
} }
return total / windowWidth * windowHeight; return total / windowWidth * windowHeight;
} }
/** Returns the two variances and the covariance of the two windows. */ /** Calculates the variances and covariance of the pixels in the window for both buffers. */
private static double[] getVariancesAndCovariance( private static double[] getVariancesAndCovariance(
byte[] pixelsX, byte[] referenceBuffer,
byte[] pixelsY, byte[] distortedBuffer,
double meanX, double referenceMean,
double meanY, double distortedMean,
int start, int bufferIndexOffset,
int stride, int stride,
int windowWidth, int windowWidth,
int windowHeight) { int windowHeight) {
// The variances in X and Y. double referenceVariance = 0;
double varX = 0; double distortedVariance = 0;
double varY = 0; double referenceDistortedCovariance = 0;
// The covariance between X and Y.
double covXY = 0;
for (int y = 0; y < windowHeight; y++) { for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) { for (int x = 0; x < windowWidth; x++) {
int index = getGlobalCoordinate(x, y, stride, start); int index = get1dIndex(x, y, stride, bufferIndexOffset);
double offsetX = pixelsX[index] - meanX; double referencePixelDeviation = referenceBuffer[index] - referenceMean;
double offsetY = pixelsY[index] - meanY; double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
varX += pow(offsetX, 2); referenceVariance += referencePixelDeviation * referencePixelDeviation;
varY += pow(offsetY, 2); distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
covXY += offsetX * offsetY; referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
} }
} }
int normalizationFactor = windowWidth * windowHeight - 1; int normalizationFactor = windowWidth * windowHeight - 1;
return new double[] { return new double[] {
varX / normalizationFactor, varY / normalizationFactor, covXY / normalizationFactor referenceVariance / normalizationFactor,
distortedVariance / normalizationFactor,
referenceDistortedCovariance / normalizationFactor
}; };
} }
private static int getGlobalCoordinate(int x, int y, int stride, int offset) { /**
* Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
*
* @param x The width component of coordinate.
* @param y The height component of coordinate.
* @param stride The width of the 2D space.
* @param offset An offset to apply.
* @return The 1D index.
*/
private static int get1dIndex(int x, int y, int stride, int offset) {
return x + (y * stride) + offset; return x + (y * stride) + offset;
} }
} }

View File

@ -333,7 +333,7 @@ public class TransformerAndroidTestRunner {
SsimHelper.calculate( SsimHelper.calculate(
context, context,
/* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(), /* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(),
outputVideoFile.getPath()); /* distortedVideoPath= */ outputVideoFile.getPath());
resultBuilder.setSsim(ssim); resultBuilder.setSsim(ssim);
} catch (InterruptedException interruptedException) { } catch (InterruptedException interruptedException) {
// InterruptedException is a special unexpected case because it is not related to Ssim // InterruptedException is a special unexpected case because it is not related to Ssim