Clarify variables and improve documentation of SSIM.
PiperOrigin-RevId: 451392021
This commit is contained in:
parent
224761833f
commit
6c4f6ecf46
@ -63,66 +63,65 @@ public final class SsimHelper {
|
|||||||
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
|
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the mean SSIM score between the reference and the actual video.
|
* Returns the mean SSIM score between the reference and the distorted video.
|
||||||
*
|
*
|
||||||
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
|
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
|
||||||
*
|
*
|
||||||
* @param context The {@link Context}.
|
* @param context The {@link Context}.
|
||||||
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
|
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
|
||||||
* Context#getAssets() Assets}.
|
* Context#getAssets() Assets}.
|
||||||
* @param actualVideoPath The path to the actual video file.
|
* @param distortedVideoPath The path to the distorted video file.
|
||||||
* @throws IOException When unable to open the provided video paths.
|
* @throws IOException When unable to open the provided video paths.
|
||||||
*/
|
*/
|
||||||
public static double calculate(Context context, String referenceVideoPath, String actualVideoPath)
|
public static double calculate(
|
||||||
|
Context context, String referenceVideoPath, String distortedVideoPath)
|
||||||
throws IOException, InterruptedException {
|
throws IOException, InterruptedException {
|
||||||
VideoDecodingWrapper referenceDecodingWrapper =
|
VideoDecodingWrapper referenceDecodingWrapper =
|
||||||
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
||||||
VideoDecodingWrapper actualDecodingWrapper =
|
VideoDecodingWrapper distortedDecodingWrapper =
|
||||||
new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
|
||||||
@Nullable byte[] referenceLumaBuffer = null;
|
@Nullable byte[] referenceLumaBuffer = null;
|
||||||
@Nullable byte[] actualLumaBuffer = null;
|
@Nullable byte[] distortedLumaBuffer = null;
|
||||||
double accumulatedSsim = 0.0;
|
double accumulatedSsim = 0.0;
|
||||||
int comparedImagesCount = 0;
|
int comparedImagesCount = 0;
|
||||||
try {
|
try {
|
||||||
while (true) {
|
while (true) {
|
||||||
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
|
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
|
||||||
@Nullable Image actualImage = actualDecodingWrapper.runUntilComparisonFrameOrEnded();
|
@Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded();
|
||||||
if (referenceImage == null) {
|
if (referenceImage == null) {
|
||||||
assertThat(actualImage).isNull();
|
assertThat(distortedImage).isNull();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
checkNotNull(actualImage);
|
checkNotNull(distortedImage);
|
||||||
|
|
||||||
int width = referenceImage.getWidth();
|
int width = referenceImage.getWidth();
|
||||||
int height = referenceImage.getHeight();
|
int height = referenceImage.getHeight();
|
||||||
|
|
||||||
assertThat(actualImage.getWidth()).isEqualTo(width);
|
assertThat(distortedImage.getWidth()).isEqualTo(width);
|
||||||
assertThat(actualImage.getHeight()).isEqualTo(height);
|
assertThat(distortedImage.getHeight()).isEqualTo(height);
|
||||||
|
|
||||||
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
|
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
|
||||||
referenceLumaBuffer = new byte[width * height];
|
referenceLumaBuffer = new byte[width * height];
|
||||||
}
|
}
|
||||||
if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) {
|
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
|
||||||
actualLumaBuffer = new byte[width * height];
|
distortedLumaBuffer = new byte[width * height];
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
accumulatedSsim +=
|
accumulatedSsim +=
|
||||||
SsimCalculator.calculate(
|
MssimCalculator.calculate(
|
||||||
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
|
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
|
||||||
extractLumaChannelBuffer(actualImage, actualLumaBuffer),
|
extractLumaChannelBuffer(distortedImage, distortedLumaBuffer),
|
||||||
/* offset= */ 0,
|
|
||||||
/* stride= */ width,
|
|
||||||
width,
|
width,
|
||||||
height);
|
height);
|
||||||
} finally {
|
} finally {
|
||||||
referenceImage.close();
|
referenceImage.close();
|
||||||
actualImage.close();
|
distortedImage.close();
|
||||||
}
|
}
|
||||||
comparedImagesCount++;
|
comparedImagesCount++;
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
referenceDecodingWrapper.close();
|
referenceDecodingWrapper.close();
|
||||||
actualDecodingWrapper.close();
|
distortedDecodingWrapper.close();
|
||||||
}
|
}
|
||||||
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
|
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
|
||||||
return accumulatedSsim / comparedImagesCount;
|
return accumulatedSsim / comparedImagesCount;
|
||||||
@ -330,55 +329,87 @@ public final class SsimHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image comparison using the Structural Similarity Index, developed by Wang, Bovik, Sheikh, and
|
* Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik,
|
||||||
* Simoncelli.
|
* Sheikh, and Simoncelli.
|
||||||
|
*
|
||||||
|
* <p>MSSIM divides the image into windows, calculates SSIM of each, then returns the average.
|
||||||
*
|
*
|
||||||
* @see <a href=https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf>The SSIM paper</a>.
|
* @see <a href=https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf>The SSIM paper</a>.
|
||||||
*/
|
*/
|
||||||
private static final class SsimCalculator {
|
private static final class MssimCalculator {
|
||||||
// These values were taken from the SSIM paper. Please see the linked paper for details.
|
// Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
|
||||||
private static final double IMAGE_DYNAMIC_RANGE = 255;
|
// range of pixel values is 0 to 255 (8 bit unsigned range).
|
||||||
|
private static final int PIXEL_MAX_VALUE = 255;
|
||||||
|
|
||||||
|
// K1 and K2, as defined in the SSIM paper.
|
||||||
private static final double K1 = 0.01;
|
private static final double K1 = 0.01;
|
||||||
private static final double K2 = 0.03;
|
private static final double K2 = 0.03;
|
||||||
private static final double C1 = pow(IMAGE_DYNAMIC_RANGE * K1, 2);
|
|
||||||
private static final double C2 = pow(IMAGE_DYNAMIC_RANGE * K2, 2);
|
// C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
|
||||||
|
// (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
|
||||||
|
// `getWindowSsim` for how these values impact each other in the calculation.
|
||||||
|
private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2);
|
||||||
|
private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2);
|
||||||
|
|
||||||
private static final int WINDOW_SIZE = 8;
|
private static final int WINDOW_SIZE = 8;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculates the Structural Similarity Index (SSIM) between two images.
|
* Calculates the Mean Structural Similarity (MSSIM) between two images.
|
||||||
*
|
*
|
||||||
* @param reference The luma channel (Y) bitmap of the reference image.
|
* @param referenceBuffer The luma channel (Y) buffer of the reference image.
|
||||||
* @param actual The luma channel (Y) bitmap of the actual image.
|
* @param distortedBuffer The luma channel (Y) buffer of the distorted image.
|
||||||
* @param offset The offset.
|
|
||||||
* @param stride The stride of the bitmap.
|
|
||||||
* @param width The image width in pixels.
|
* @param width The image width in pixels.
|
||||||
* @param height The image height in pixels.
|
* @param height The image height in pixels.
|
||||||
* @return The SSIM score between the input images.
|
* @return The MSSIM score between the input images.
|
||||||
*/
|
*/
|
||||||
public static double calculate(
|
public static double calculate(
|
||||||
byte[] reference, byte[] actual, int offset, int stride, int width, int height) {
|
byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
|
||||||
double totalSsim = 0;
|
double totalSsim = 0;
|
||||||
int windowsCount = 0;
|
int windowsCount = 0;
|
||||||
|
|
||||||
// X refers to the reference image, while Y refers to the actual image.
|
|
||||||
for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
|
for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
|
||||||
int windowHeight = computeWindowSize(currentWindowY, height);
|
int windowHeight = computeWindowSize(currentWindowY, height);
|
||||||
for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
|
for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
|
||||||
windowsCount++;
|
windowsCount++;
|
||||||
int windowWidth = computeWindowSize(currentWindowX, width);
|
int windowWidth = computeWindowSize(currentWindowX, width);
|
||||||
int start = getGlobalCoordinate(currentWindowX, currentWindowY, stride, offset);
|
int bufferIndexOffset =
|
||||||
double meanX = getMean(reference, start, stride, windowWidth, windowHeight);
|
get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
|
||||||
double meanY = getMean(actual, start, stride, windowWidth, windowHeight);
|
double referenceMean =
|
||||||
|
getMean(
|
||||||
|
referenceBuffer,
|
||||||
|
bufferIndexOffset,
|
||||||
|
/* stride= */ width,
|
||||||
|
windowWidth,
|
||||||
|
windowHeight);
|
||||||
|
double distortedMean =
|
||||||
|
getMean(
|
||||||
|
distortedBuffer,
|
||||||
|
bufferIndexOffset,
|
||||||
|
/* stride= */ width,
|
||||||
|
windowWidth,
|
||||||
|
windowHeight);
|
||||||
|
|
||||||
double[] variances =
|
double[] variances =
|
||||||
getVariancesAndCovariance(
|
getVariancesAndCovariance(
|
||||||
reference, actual, meanX, meanY, start, stride, windowWidth, windowHeight);
|
referenceBuffer,
|
||||||
// varX is the variance of window X, covXY is the covariance between window X and Y.
|
distortedBuffer,
|
||||||
double varX = variances[0];
|
referenceMean,
|
||||||
double varY = variances[1];
|
distortedMean,
|
||||||
double covXY = variances[2];
|
bufferIndexOffset,
|
||||||
|
/* stride= */ width,
|
||||||
|
windowWidth,
|
||||||
|
windowHeight);
|
||||||
|
double referenceVariance = variances[0];
|
||||||
|
double distortedVariance = variances[1];
|
||||||
|
double referenceDistortedCovariance = variances[2];
|
||||||
|
|
||||||
totalSsim += getWindowSsim(meanX, meanY, varX, varY, covXY);
|
totalSsim +=
|
||||||
|
getWindowSsim(
|
||||||
|
referenceMean,
|
||||||
|
distortedMean,
|
||||||
|
referenceVariance,
|
||||||
|
distortedVariance,
|
||||||
|
referenceDistortedCovariance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -402,59 +433,76 @@ public final class SsimHelper {
|
|||||||
|
|
||||||
/** Returns the SSIM of a window. */
|
/** Returns the SSIM of a window. */
|
||||||
private static double getWindowSsim(
|
private static double getWindowSsim(
|
||||||
double meanX, double meanY, double varX, double varY, double covXY) {
|
double referenceMean,
|
||||||
|
double distortedMean,
|
||||||
|
double referenceVariance,
|
||||||
|
double distortedVariance,
|
||||||
|
double referenceDistortedCovariance) {
|
||||||
|
|
||||||
// Uses equation 13 on page 6 from the linked paper.
|
// Uses equation 13 on page 6 from the linked paper.
|
||||||
double numerator = (((2 * meanX * meanY) + C1) * ((2 * covXY) + C2));
|
double numerator =
|
||||||
double denominator = ((meanX * meanX) + (meanY * meanY) + C1) * (varX + varY + C2);
|
(((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
|
||||||
|
double denominator =
|
||||||
|
((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
|
||||||
|
* (referenceVariance + distortedVariance + C2);
|
||||||
return numerator / denominator;
|
return numerator / denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the means of the pixels in the two provided windows, in order. */
|
/** Returns the mean of the pixels in the window. */
|
||||||
private static double getMean(
|
private static double getMean(
|
||||||
byte[] pixels, int start, int stride, int windowWidth, int windowHeight) {
|
byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
|
||||||
double total = 0;
|
double total = 0;
|
||||||
for (int y = 0; y < windowHeight; y++) {
|
for (int y = 0; y < windowHeight; y++) {
|
||||||
for (int x = 0; x < windowWidth; x++) {
|
for (int x = 0; x < windowWidth; x++) {
|
||||||
total += pixels[getGlobalCoordinate(x, y, stride, start)];
|
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return total / windowWidth * windowHeight;
|
return total / windowWidth * windowHeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the two variances and the covariance of the two windows. */
|
/** Calculates the variances and covariance of the pixels in the window for both buffers. */
|
||||||
private static double[] getVariancesAndCovariance(
|
private static double[] getVariancesAndCovariance(
|
||||||
byte[] pixelsX,
|
byte[] referenceBuffer,
|
||||||
byte[] pixelsY,
|
byte[] distortedBuffer,
|
||||||
double meanX,
|
double referenceMean,
|
||||||
double meanY,
|
double distortedMean,
|
||||||
int start,
|
int bufferIndexOffset,
|
||||||
int stride,
|
int stride,
|
||||||
int windowWidth,
|
int windowWidth,
|
||||||
int windowHeight) {
|
int windowHeight) {
|
||||||
// The variances in X and Y.
|
double referenceVariance = 0;
|
||||||
double varX = 0;
|
double distortedVariance = 0;
|
||||||
double varY = 0;
|
double referenceDistortedCovariance = 0;
|
||||||
// The covariance between X and Y.
|
|
||||||
double covXY = 0;
|
|
||||||
for (int y = 0; y < windowHeight; y++) {
|
for (int y = 0; y < windowHeight; y++) {
|
||||||
for (int x = 0; x < windowWidth; x++) {
|
for (int x = 0; x < windowWidth; x++) {
|
||||||
int index = getGlobalCoordinate(x, y, stride, start);
|
int index = get1dIndex(x, y, stride, bufferIndexOffset);
|
||||||
double offsetX = pixelsX[index] - meanX;
|
double referencePixelDeviation = referenceBuffer[index] - referenceMean;
|
||||||
double offsetY = pixelsY[index] - meanY;
|
double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
|
||||||
varX += pow(offsetX, 2);
|
referenceVariance += referencePixelDeviation * referencePixelDeviation;
|
||||||
varY += pow(offsetY, 2);
|
distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
|
||||||
covXY += offsetX * offsetY;
|
referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int normalizationFactor = windowWidth * windowHeight - 1;
|
int normalizationFactor = windowWidth * windowHeight - 1;
|
||||||
|
|
||||||
return new double[] {
|
return new double[] {
|
||||||
varX / normalizationFactor, varY / normalizationFactor, covXY / normalizationFactor
|
referenceVariance / normalizationFactor,
|
||||||
|
distortedVariance / normalizationFactor,
|
||||||
|
referenceDistortedCovariance / normalizationFactor
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int getGlobalCoordinate(int x, int y, int stride, int offset) {
|
/**
|
||||||
|
* Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
|
||||||
|
*
|
||||||
|
* @param x The width component of coordinate.
|
||||||
|
* @param y The height component of coordinate.
|
||||||
|
* @param stride The width of the 2D space.
|
||||||
|
* @param offset An offset to apply.
|
||||||
|
* @return The 1D index.
|
||||||
|
*/
|
||||||
|
private static int get1dIndex(int x, int y, int stride, int offset) {
|
||||||
return x + (y * stride) + offset;
|
return x + (y * stride) + offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -333,7 +333,7 @@ public class TransformerAndroidTestRunner {
|
|||||||
SsimHelper.calculate(
|
SsimHelper.calculate(
|
||||||
context,
|
context,
|
||||||
/* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(),
|
/* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(),
|
||||||
outputVideoFile.getPath());
|
/* distortedVideoPath= */ outputVideoFile.getPath());
|
||||||
resultBuilder.setSsim(ssim);
|
resultBuilder.setSsim(ssim);
|
||||||
} catch (InterruptedException interruptedException) {
|
} catch (InterruptedException interruptedException) {
|
||||||
// InterruptedException is a special unexpected case because it is not related to Ssim
|
// InterruptedException is a special unexpected case because it is not related to Ssim
|
||||||
|
Loading…
x
Reference in New Issue
Block a user