Clarify variables and improve documentation of SSIM.

PiperOrigin-RevId: 451392021
This commit is contained in:
samrobinson 2022-05-27 14:32:22 +00:00 committed by Marc Baechinger
parent 224761833f
commit 6c4f6ecf46
2 changed files with 117 additions and 69 deletions

View File

@ -63,66 +63,65 @@ public final class SsimHelper {
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
/**
* Returns the mean SSIM score between the reference and the actual video.
* Returns the mean SSIM score between the reference and the distorted video.
*
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
*
* @param context The {@link Context}.
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
* Context#getAssets() Assets}.
* @param actualVideoPath The path to the actual video file.
* @param distortedVideoPath The path to the distorted video file.
* @throws IOException When unable to open the provided video paths.
*/
public static double calculate(Context context, String referenceVideoPath, String actualVideoPath)
public static double calculate(
Context context, String referenceVideoPath, String distortedVideoPath)
throws IOException, InterruptedException {
VideoDecodingWrapper referenceDecodingWrapper =
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper actualDecodingWrapper =
new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper distortedDecodingWrapper =
new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
@Nullable byte[] referenceLumaBuffer = null;
@Nullable byte[] actualLumaBuffer = null;
@Nullable byte[] distortedLumaBuffer = null;
double accumulatedSsim = 0.0;
int comparedImagesCount = 0;
try {
while (true) {
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
@Nullable Image actualImage = actualDecodingWrapper.runUntilComparisonFrameOrEnded();
@Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded();
if (referenceImage == null) {
assertThat(actualImage).isNull();
assertThat(distortedImage).isNull();
break;
}
checkNotNull(actualImage);
checkNotNull(distortedImage);
int width = referenceImage.getWidth();
int height = referenceImage.getHeight();
assertThat(actualImage.getWidth()).isEqualTo(width);
assertThat(actualImage.getHeight()).isEqualTo(height);
assertThat(distortedImage.getWidth()).isEqualTo(width);
assertThat(distortedImage.getHeight()).isEqualTo(height);
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
referenceLumaBuffer = new byte[width * height];
}
if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) {
actualLumaBuffer = new byte[width * height];
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
distortedLumaBuffer = new byte[width * height];
}
try {
accumulatedSsim +=
SsimCalculator.calculate(
MssimCalculator.calculate(
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
extractLumaChannelBuffer(actualImage, actualLumaBuffer),
/* offset= */ 0,
/* stride= */ width,
extractLumaChannelBuffer(distortedImage, distortedLumaBuffer),
width,
height);
} finally {
referenceImage.close();
actualImage.close();
distortedImage.close();
}
comparedImagesCount++;
}
} finally {
referenceDecodingWrapper.close();
actualDecodingWrapper.close();
distortedDecodingWrapper.close();
}
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
return accumulatedSsim / comparedImagesCount;
@ -330,55 +329,87 @@ public final class SsimHelper {
}
/**
* Image comparison using the Structural Similarity Index, developed by Wang, Bovik, Sheikh, and
* Simoncelli.
* Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik,
* Sheikh, and Simoncelli.
*
* <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
*
* @see <a href="https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf">The SSIM paper</a>.
*/
private static final class SsimCalculator {
// These values were taken from the SSIM paper. Please see the linked paper for details.
private static final double IMAGE_DYNAMIC_RANGE = 255;
private static final class MssimCalculator {
// Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
// range of pixel values is 0 to 255 (8 bit unsigned range).
private static final int PIXEL_MAX_VALUE = 255;
// K1 and K2, as defined in the SSIM paper.
private static final double K1 = 0.01;
private static final double K2 = 0.03;
private static final double C1 = pow(IMAGE_DYNAMIC_RANGE * K1, 2);
private static final double C2 = pow(IMAGE_DYNAMIC_RANGE * K2, 2);
// C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
// (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
// `getWindowSsim` for how these values impact each other in the calculation.
private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);
private static final int WINDOW_SIZE = 8;
/**
* Calculates the Structural Similarity Index (SSIM) between two images.
* Calculates the Mean Structural Similarity (MSSIM) between two images.
*
* @param reference The luma channel (Y) bitmap of the reference image.
* @param actual The luma channel (Y) bitmap of the actual image.
* @param offset The offset.
* @param stride The stride of the bitmap.
* @param referenceBuffer The luma channel (Y) buffer of the reference image.
* @param distortedBuffer The luma channel (Y) buffer of the distorted image.
* @param width The image width in pixels.
* @param height The image height in pixels.
* @return The SSIM score between the input images.
* @return The MSSIM score between the input images.
*/
public static double calculate(
byte[] reference, byte[] actual, int offset, int stride, int width, int height) {
byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
double totalSsim = 0;
int windowsCount = 0;
// X refers to the reference image, while Y refers to the actual image.
for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
int windowHeight = computeWindowSize(currentWindowY, height);
for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
windowsCount++;
int windowWidth = computeWindowSize(currentWindowX, width);
int start = getGlobalCoordinate(currentWindowX, currentWindowY, stride, offset);
double meanX = getMean(reference, start, stride, windowWidth, windowHeight);
double meanY = getMean(actual, start, stride, windowWidth, windowHeight);
int bufferIndexOffset =
get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
double referenceMean =
getMean(
referenceBuffer,
bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double distortedMean =
getMean(
distortedBuffer,
bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double[] variances =
getVariancesAndCovariance(
reference, actual, meanX, meanY, start, stride, windowWidth, windowHeight);
// varX is the variance of window X, covXY is the covariance between window X and Y.
double varX = variances[0];
double varY = variances[1];
double covXY = variances[2];
referenceBuffer,
distortedBuffer,
referenceMean,
distortedMean,
bufferIndexOffset,
/* stride= */ width,
windowWidth,
windowHeight);
double referenceVariance = variances[0];
double distortedVariance = variances[1];
double referenceDistortedCovariance = variances[2];
totalSsim += getWindowSsim(meanX, meanY, varX, varY, covXY);
totalSsim +=
getWindowSsim(
referenceMean,
distortedMean,
referenceVariance,
distortedVariance,
referenceDistortedCovariance);
}
}
@ -402,59 +433,76 @@ public final class SsimHelper {
/** Returns the SSIM of a window. */
private static double getWindowSsim(
double meanX, double meanY, double varX, double varY, double covXY) {
double referenceMean,
double distortedMean,
double referenceVariance,
double distortedVariance,
double referenceDistortedCovariance) {
// Uses equation 13 on page 6 from the linked paper.
double numerator = (((2 * meanX * meanY) + C1) * ((2 * covXY) + C2));
double denominator = ((meanX * meanX) + (meanY * meanY) + C1) * (varX + varY + C2);
double numerator =
(((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
double denominator =
((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
* (referenceVariance + distortedVariance + C2);
return numerator / denominator;
}
/** Returns the means of the pixels in the two provided windows, in order. */
/** Returns the mean of the pixels in the window. */
private static double getMean(
byte[] pixels, int start, int stride, int windowWidth, int windowHeight) {
byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
double total = 0;
for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) {
total += pixels[getGlobalCoordinate(x, y, stride, start)];
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
}
}
return total / windowWidth * windowHeight;
}
/** Returns the two variances and the covariance of the two windows. */
/** Calculates the variances and covariance of the pixels in the window for both buffers. */
private static double[] getVariancesAndCovariance(
byte[] pixelsX,
byte[] pixelsY,
double meanX,
double meanY,
int start,
byte[] referenceBuffer,
byte[] distortedBuffer,
double referenceMean,
double distortedMean,
int bufferIndexOffset,
int stride,
int windowWidth,
int windowHeight) {
// The variances in X and Y.
double varX = 0;
double varY = 0;
// The covariance between X and Y.
double covXY = 0;
double referenceVariance = 0;
double distortedVariance = 0;
double referenceDistortedCovariance = 0;
for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) {
int index = getGlobalCoordinate(x, y, stride, start);
double offsetX = pixelsX[index] - meanX;
double offsetY = pixelsY[index] - meanY;
varX += pow(offsetX, 2);
varY += pow(offsetY, 2);
covXY += offsetX * offsetY;
int index = get1dIndex(x, y, stride, bufferIndexOffset);
double referencePixelDeviation = referenceBuffer[index] - referenceMean;
double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
referenceVariance += referencePixelDeviation * referencePixelDeviation;
distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
}
}
int normalizationFactor = windowWidth * windowHeight - 1;
return new double[] {
varX / normalizationFactor, varY / normalizationFactor, covXY / normalizationFactor
referenceVariance / normalizationFactor,
distortedVariance / normalizationFactor,
referenceDistortedCovariance / normalizationFactor
};
}
private static int getGlobalCoordinate(int x, int y, int stride, int offset) {
/**
* Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
*
* @param x The width component of coordinate.
* @param y The height component of coordinate.
* @param stride The width of the 2D space.
* @param offset An offset to apply.
* @return The 1D index.
*/
private static int get1dIndex(int x, int y, int stride, int offset) {
return x + (y * stride) + offset;
}
}

View File

@ -333,7 +333,7 @@ public class TransformerAndroidTestRunner {
SsimHelper.calculate(
context,
/* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(),
outputVideoFile.getPath());
/* distortedVideoPath= */ outputVideoFile.getPath());
resultBuilder.setSsim(ssim);
} catch (InterruptedException interruptedException) {
// InterruptedException is a special unexpected case because it is not related to Ssim