Clarify variables and improve documentation of SSIM.
PiperOrigin-RevId: 451392021
This commit is contained in:
parent
224761833f
commit
6c4f6ecf46
@ -63,66 +63,65 @@ public final class SsimHelper {
|
||||
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
|
||||
|
||||
/**
|
||||
* Returns the mean SSIM score between the reference and the actual video.
|
||||
* Returns the mean SSIM score between the reference and the distorted video.
|
||||
*
|
||||
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
|
||||
*
|
||||
* @param context The {@link Context}.
|
||||
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
|
||||
* Context#getAssets() Assets}.
|
||||
* @param actualVideoPath The path to the actual video file.
|
||||
* @param distortedVideoPath The path to the distorted video file.
|
||||
* @throws IOException When unable to open the provided video paths.
|
||||
*/
|
||||
public static double calculate(Context context, String referenceVideoPath, String actualVideoPath)
|
||||
public static double calculate(
|
||||
Context context, String referenceVideoPath, String distortedVideoPath)
|
||||
throws IOException, InterruptedException {
|
||||
VideoDecodingWrapper referenceDecodingWrapper =
|
||||
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
||||
VideoDecodingWrapper actualDecodingWrapper =
|
||||
new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
||||
VideoDecodingWrapper distortedDecodingWrapper =
|
||||
new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
||||
@Nullable byte[] referenceLumaBuffer = null;
|
||||
@Nullable byte[] actualLumaBuffer = null;
|
||||
@Nullable byte[] distortedLumaBuffer = null;
|
||||
double accumulatedSsim = 0.0;
|
||||
int comparedImagesCount = 0;
|
||||
try {
|
||||
while (true) {
|
||||
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
|
||||
@Nullable Image actualImage = actualDecodingWrapper.runUntilComparisonFrameOrEnded();
|
||||
@Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded();
|
||||
if (referenceImage == null) {
|
||||
assertThat(actualImage).isNull();
|
||||
assertThat(distortedImage).isNull();
|
||||
break;
|
||||
}
|
||||
checkNotNull(actualImage);
|
||||
checkNotNull(distortedImage);
|
||||
|
||||
int width = referenceImage.getWidth();
|
||||
int height = referenceImage.getHeight();
|
||||
|
||||
assertThat(actualImage.getWidth()).isEqualTo(width);
|
||||
assertThat(actualImage.getHeight()).isEqualTo(height);
|
||||
assertThat(distortedImage.getWidth()).isEqualTo(width);
|
||||
assertThat(distortedImage.getHeight()).isEqualTo(height);
|
||||
|
||||
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
|
||||
referenceLumaBuffer = new byte[width * height];
|
||||
}
|
||||
if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) {
|
||||
actualLumaBuffer = new byte[width * height];
|
||||
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
|
||||
distortedLumaBuffer = new byte[width * height];
|
||||
}
|
||||
try {
|
||||
accumulatedSsim +=
|
||||
SsimCalculator.calculate(
|
||||
MssimCalculator.calculate(
|
||||
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
|
||||
extractLumaChannelBuffer(actualImage, actualLumaBuffer),
|
||||
/* offset= */ 0,
|
||||
/* stride= */ width,
|
||||
extractLumaChannelBuffer(distortedImage, distortedLumaBuffer),
|
||||
width,
|
||||
height);
|
||||
} finally {
|
||||
referenceImage.close();
|
||||
actualImage.close();
|
||||
distortedImage.close();
|
||||
}
|
||||
comparedImagesCount++;
|
||||
}
|
||||
} finally {
|
||||
referenceDecodingWrapper.close();
|
||||
actualDecodingWrapper.close();
|
||||
distortedDecodingWrapper.close();
|
||||
}
|
||||
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
|
||||
return accumulatedSsim / comparedImagesCount;
|
||||
@ -330,55 +329,87 @@ public final class SsimHelper {
|
||||
}
|
||||
|
||||
/**
|
||||
* Image comparison using the Structural Similarity Index, developed by Wang, Bovik, Sheikh, and
|
||||
* Simoncelli.
|
||||
* Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik,
|
||||
* Sheikh, and Simoncelli.
|
||||
*
|
||||
* <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
|
||||
*
|
||||
* @see <a href=https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf>The SSIM paper</a>.
|
||||
*/
|
||||
private static final class SsimCalculator {
|
||||
// These values were taken from the SSIM paper. Please see the linked paper for details.
|
||||
private static final double IMAGE_DYNAMIC_RANGE = 255;
|
||||
private static final class MssimCalculator {
|
||||
// Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
|
||||
// range of pixel values is 0 to 255 (8 bit unsigned range).
|
||||
private static final int PIXEL_MAX_VALUE = 255;
|
||||
|
||||
// K1 and K2, as defined in the SSIM paper.
|
||||
private static final double K1 = 0.01;
|
||||
private static final double K2 = 0.03;
|
||||
private static final double C1 = pow(IMAGE_DYNAMIC_RANGE * K1, 2);
|
||||
private static final double C2 = pow(IMAGE_DYNAMIC_RANGE * K2, 2);
|
||||
|
||||
// C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
|
||||
// (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
|
||||
// `getWindowSsim` for how these values impact each other in the calculation.
|
||||
private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
|
||||
private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);
|
||||
|
||||
private static final int WINDOW_SIZE = 8;
|
||||
|
||||
/**
|
||||
* Calculates the Structural Similarity Index (SSIM) between two images.
|
||||
* Calculates the Mean Structural Similarity (MSSIM) between two images.
|
||||
*
|
||||
* @param reference The luma channel (Y) bitmap of the reference image.
|
||||
* @param actual The luma channel (Y) bitmap of the actual image.
|
||||
* @param offset The offset.
|
||||
* @param stride The stride of the bitmap.
|
||||
* @param referenceBuffer The luma channel (Y) buffer of the reference image.
|
||||
* @param distortedBuffer The luma channel (Y) buffer of the distorted image.
|
||||
* @param width The image width in pixels.
|
||||
* @param height The image height in pixels.
|
||||
* @return The SSIM score between the input images.
|
||||
* @return The MSSIM score between the input images.
|
||||
*/
|
||||
public static double calculate(
|
||||
byte[] reference, byte[] actual, int offset, int stride, int width, int height) {
|
||||
byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
|
||||
double totalSsim = 0;
|
||||
int windowsCount = 0;
|
||||
|
||||
// X refers to the reference image, while Y refers to the actual image.
|
||||
for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
|
||||
int windowHeight = computeWindowSize(currentWindowY, height);
|
||||
for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
|
||||
windowsCount++;
|
||||
int windowWidth = computeWindowSize(currentWindowX, width);
|
||||
int start = getGlobalCoordinate(currentWindowX, currentWindowY, stride, offset);
|
||||
double meanX = getMean(reference, start, stride, windowWidth, windowHeight);
|
||||
double meanY = getMean(actual, start, stride, windowWidth, windowHeight);
|
||||
int bufferIndexOffset =
|
||||
get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
|
||||
double referenceMean =
|
||||
getMean(
|
||||
referenceBuffer,
|
||||
bufferIndexOffset,
|
||||
/* stride= */ width,
|
||||
windowWidth,
|
||||
windowHeight);
|
||||
double distortedMean =
|
||||
getMean(
|
||||
distortedBuffer,
|
||||
bufferIndexOffset,
|
||||
/* stride= */ width,
|
||||
windowWidth,
|
||||
windowHeight);
|
||||
|
||||
double[] variances =
|
||||
getVariancesAndCovariance(
|
||||
reference, actual, meanX, meanY, start, stride, windowWidth, windowHeight);
|
||||
// varX is the variance of window X, covXY is the covariance between window X and Y.
|
||||
double varX = variances[0];
|
||||
double varY = variances[1];
|
||||
double covXY = variances[2];
|
||||
referenceBuffer,
|
||||
distortedBuffer,
|
||||
referenceMean,
|
||||
distortedMean,
|
||||
bufferIndexOffset,
|
||||
/* stride= */ width,
|
||||
windowWidth,
|
||||
windowHeight);
|
||||
double referenceVariance = variances[0];
|
||||
double distortedVariance = variances[1];
|
||||
double referenceDistortedCovariance = variances[2];
|
||||
|
||||
totalSsim += getWindowSsim(meanX, meanY, varX, varY, covXY);
|
||||
totalSsim +=
|
||||
getWindowSsim(
|
||||
referenceMean,
|
||||
distortedMean,
|
||||
referenceVariance,
|
||||
distortedVariance,
|
||||
referenceDistortedCovariance);
|
||||
}
|
||||
}
|
||||
|
||||
@ -402,59 +433,76 @@ public final class SsimHelper {
|
||||
|
||||
/** Returns the SSIM of a window. */
|
||||
private static double getWindowSsim(
|
||||
double meanX, double meanY, double varX, double varY, double covXY) {
|
||||
double referenceMean,
|
||||
double distortedMean,
|
||||
double referenceVariance,
|
||||
double distortedVariance,
|
||||
double referenceDistortedCovariance) {
|
||||
|
||||
// Uses equation 13 on page 6 from the linked paper.
|
||||
double numerator = (((2 * meanX * meanY) + C1) * ((2 * covXY) + C2));
|
||||
double denominator = ((meanX * meanX) + (meanY * meanY) + C1) * (varX + varY + C2);
|
||||
double numerator =
|
||||
(((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
|
||||
double denominator =
|
||||
((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
|
||||
* (referenceVariance + distortedVariance + C2);
|
||||
return numerator / denominator;
|
||||
}
|
||||
|
||||
/** Returns the means of the pixels in the two provided windows, in order. */
|
||||
/** Returns the mean of the pixels in the window. */
|
||||
private static double getMean(
|
||||
byte[] pixels, int start, int stride, int windowWidth, int windowHeight) {
|
||||
byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
|
||||
double total = 0;
|
||||
for (int y = 0; y < windowHeight; y++) {
|
||||
for (int x = 0; x < windowWidth; x++) {
|
||||
total += pixels[getGlobalCoordinate(x, y, stride, start)];
|
||||
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
|
||||
}
|
||||
}
|
||||
return total / windowWidth * windowHeight;
|
||||
}
|
||||
|
||||
/** Returns the two variances and the covariance of the two windows. */
|
||||
/** Calculates the variances and covariance of the pixels in the window for both buffers. */
|
||||
private static double[] getVariancesAndCovariance(
|
||||
byte[] pixelsX,
|
||||
byte[] pixelsY,
|
||||
double meanX,
|
||||
double meanY,
|
||||
int start,
|
||||
byte[] referenceBuffer,
|
||||
byte[] distortedBuffer,
|
||||
double referenceMean,
|
||||
double distortedMean,
|
||||
int bufferIndexOffset,
|
||||
int stride,
|
||||
int windowWidth,
|
||||
int windowHeight) {
|
||||
// The variances in X and Y.
|
||||
double varX = 0;
|
||||
double varY = 0;
|
||||
// The covariance between X and Y.
|
||||
double covXY = 0;
|
||||
double referenceVariance = 0;
|
||||
double distortedVariance = 0;
|
||||
double referenceDistortedCovariance = 0;
|
||||
for (int y = 0; y < windowHeight; y++) {
|
||||
for (int x = 0; x < windowWidth; x++) {
|
||||
int index = getGlobalCoordinate(x, y, stride, start);
|
||||
double offsetX = pixelsX[index] - meanX;
|
||||
double offsetY = pixelsY[index] - meanY;
|
||||
varX += pow(offsetX, 2);
|
||||
varY += pow(offsetY, 2);
|
||||
covXY += offsetX * offsetY;
|
||||
int index = get1dIndex(x, y, stride, bufferIndexOffset);
|
||||
double referencePixelDeviation = referenceBuffer[index] - referenceMean;
|
||||
double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
|
||||
referenceVariance += referencePixelDeviation * referencePixelDeviation;
|
||||
distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
|
||||
referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
|
||||
}
|
||||
}
|
||||
|
||||
int normalizationFactor = windowWidth * windowHeight - 1;
|
||||
|
||||
return new double[] {
|
||||
varX / normalizationFactor, varY / normalizationFactor, covXY / normalizationFactor
|
||||
referenceVariance / normalizationFactor,
|
||||
distortedVariance / normalizationFactor,
|
||||
referenceDistortedCovariance / normalizationFactor
|
||||
};
|
||||
}
|
||||
|
||||
private static int getGlobalCoordinate(int x, int y, int stride, int offset) {
|
||||
/**
|
||||
* Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
|
||||
*
|
||||
* @param x The width component of coordinate.
|
||||
* @param y The height component of coordinate.
|
||||
* @param stride The width of the 2D space.
|
||||
* @param offset An offset to apply.
|
||||
* @return The 1D index.
|
||||
*/
|
||||
private static int get1dIndex(int x, int y, int stride, int offset) {
|
||||
return x + (y * stride) + offset;
|
||||
}
|
||||
}
|
||||
|
@ -333,7 +333,7 @@ public class TransformerAndroidTestRunner {
|
||||
SsimHelper.calculate(
|
||||
context,
|
||||
/* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(),
|
||||
outputVideoFile.getPath());
|
||||
/* distortedVideoPath= */ outputVideoFile.getPath());
|
||||
resultBuilder.setSsim(ssim);
|
||||
} catch (InterruptedException interruptedException) {
|
||||
// InterruptedException is a special unexpected case because it is not related to Ssim
|
||||
|
Loading…
x
Reference in New Issue
Block a user