diff --git a/libraries/transformer/src/androidTest/java/androidx/media3/transformer/SsimHelper.java b/libraries/transformer/src/androidTest/java/androidx/media3/transformer/SsimHelper.java index b504e28ac5..369078cf36 100644 --- a/libraries/transformer/src/androidTest/java/androidx/media3/transformer/SsimHelper.java +++ b/libraries/transformer/src/androidTest/java/androidx/media3/transformer/SsimHelper.java @@ -63,66 +63,65 @@ public final class SsimHelper { private static final int DECODED_IMAGE_CHANNEL_COUNT = 3; /** - * Returns the mean SSIM score between the reference and the actual video. + * Returns the mean SSIM score between the reference and the distorted video. * *

The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos. * * @param context The {@link Context}. * @param referenceVideoPath The path to the reference video file, which must be in {@linkplain * Context#getAssets() Assets}. - * @param actualVideoPath The path to the actual video file. + * @param distortedVideoPath The path to the distorted video file. * @throws IOException When unable to open the provided video paths. */ - public static double calculate(Context context, String referenceVideoPath, String actualVideoPath) + public static double calculate( + Context context, String referenceVideoPath, String distortedVideoPath) throws IOException, InterruptedException { VideoDecodingWrapper referenceDecodingWrapper = new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL); - VideoDecodingWrapper actualDecodingWrapper = - new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL); + VideoDecodingWrapper distortedDecodingWrapper = + new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL); @Nullable byte[] referenceLumaBuffer = null; - @Nullable byte[] actualLumaBuffer = null; + @Nullable byte[] distortedLumaBuffer = null; double accumulatedSsim = 0.0; int comparedImagesCount = 0; try { while (true) { @Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded(); - @Nullable Image actualImage = actualDecodingWrapper.runUntilComparisonFrameOrEnded(); + @Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded(); if (referenceImage == null) { - assertThat(actualImage).isNull(); + assertThat(distortedImage).isNull(); break; } - checkNotNull(actualImage); + checkNotNull(distortedImage); int width = referenceImage.getWidth(); int height = referenceImage.getHeight(); - assertThat(actualImage.getWidth()).isEqualTo(width); - assertThat(actualImage.getHeight()).isEqualTo(height); + 
assertThat(distortedImage.getWidth()).isEqualTo(width); + assertThat(distortedImage.getHeight()).isEqualTo(height); if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) { referenceLumaBuffer = new byte[width * height]; } - if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) { - actualLumaBuffer = new byte[width * height]; + if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) { + distortedLumaBuffer = new byte[width * height]; } try { accumulatedSsim += - SsimCalculator.calculate( + MssimCalculator.calculate( extractLumaChannelBuffer(referenceImage, referenceLumaBuffer), - extractLumaChannelBuffer(actualImage, actualLumaBuffer), - /* offset= */ 0, - /* stride= */ width, + extractLumaChannelBuffer(distortedImage, distortedLumaBuffer), width, height); } finally { referenceImage.close(); - actualImage.close(); + distortedImage.close(); } comparedImagesCount++; } } finally { referenceDecodingWrapper.close(); - actualDecodingWrapper.close(); + distortedDecodingWrapper.close(); } assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0); return accumulatedSsim / comparedImagesCount; @@ -330,55 +329,87 @@ public final class SsimHelper { } /** - * Image comparison using the Structural Similarity Index, developed by Wang, Bovik, Sheikh, and - * Simoncelli. + * Image comparison using the Mean Structural Similarity (MSSIM), developed by Wang, Bovik, + * Sheikh, and Simoncelli. + * + *

MSSIM divides the image into windows, calculates SSIM of each, then returns the average. * * @see The SSIM paper. */ - private static final class SsimCalculator { - // These values were taken from the SSIM paper. Please see the linked paper for details. - private static final double IMAGE_DYNAMIC_RANGE = 255; + private static final class MssimCalculator { + // Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The + // range of pixel values is 0 to 255 (8 bit unsigned range). + private static final int PIXEL_MAX_VALUE = 255; + + // K1 and K2, as defined in the SSIM paper. private static final double K1 = 0.01; private static final double K2 = 0.03; - private static final double C1 = pow(IMAGE_DYNAMIC_RANGE * K1, 2); - private static final double C2 = pow(IMAGE_DYNAMIC_RANGE * K2, 2); + + // C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or + // (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in + // `getWindowSsim` for how these values impact each other in the calculation. + private static final double C1 = pow(PIXEL_MAX_VALUE * K1, 2); + private static final double C2 = pow(PIXEL_MAX_VALUE * K2, 2); + private static final int WINDOW_SIZE = 8; /** - * Calculates the Structural Similarity Index (SSIM) between two images. + * Calculates the Mean Structural Similarity (MSSIM) between two images. * - * @param reference The luma channel (Y) bitmap of the reference image. - * @param actual The luma channel (Y) bitmap of the actual image. - * @param offset The offset. - * @param stride The stride of the bitmap. + * @param referenceBuffer The luma channel (Y) buffer of the reference image. + * @param distortedBuffer The luma channel (Y) buffer of the distorted image. * @param width The image width in pixels. * @param height The image height in pixels. - * @return The SSIM score between the input images. + * @return The MSSIM score between the input images. 
*/ public static double calculate( - byte[] reference, byte[] actual, int offset, int stride, int width, int height) { + byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) { double totalSsim = 0; int windowsCount = 0; - // X refers to the reference image, while Y refers to the actual image. for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) { int windowHeight = computeWindowSize(currentWindowY, height); for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) { windowsCount++; int windowWidth = computeWindowSize(currentWindowX, width); - int start = getGlobalCoordinate(currentWindowX, currentWindowY, stride, offset); - double meanX = getMean(reference, start, stride, windowWidth, windowHeight); - double meanY = getMean(actual, start, stride, windowWidth, windowHeight); + int bufferIndexOffset = + get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0); + double referenceMean = + getMean( + referenceBuffer, + bufferIndexOffset, + /* stride= */ width, + windowWidth, + windowHeight); + double distortedMean = + getMean( + distortedBuffer, + bufferIndexOffset, + /* stride= */ width, + windowWidth, + windowHeight); double[] variances = getVariancesAndCovariance( - reference, actual, meanX, meanY, start, stride, windowWidth, windowHeight); - // varX is the variance of window X, covXY is the covariance between window X and Y. 
- double varX = variances[0]; - double varY = variances[1]; - double covXY = variances[2]; + referenceBuffer, + distortedBuffer, + referenceMean, + distortedMean, + bufferIndexOffset, + /* stride= */ width, + windowWidth, + windowHeight); + double referenceVariance = variances[0]; + double distortedVariance = variances[1]; + double referenceDistortedCovariance = variances[2]; - totalSsim += getWindowSsim(meanX, meanY, varX, varY, covXY); + totalSsim += + getWindowSsim( + referenceMean, + distortedMean, + referenceVariance, + distortedVariance, + referenceDistortedCovariance); } } @@ -402,59 +433,76 @@ public final class SsimHelper { /** Returns the SSIM of a window. */ private static double getWindowSsim( - double meanX, double meanY, double varX, double varY, double covXY) { + double referenceMean, + double distortedMean, + double referenceVariance, + double distortedVariance, + double referenceDistortedCovariance) { // Uses equation 13 on page 6 from the linked paper. - double numerator = (((2 * meanX * meanY) + C1) * ((2 * covXY) + C2)); - double denominator = ((meanX * meanX) + (meanY * meanY) + C1) * (varX + varY + C2); + double numerator = + (((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2)); + double denominator = + ((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1) + * (referenceVariance + distortedVariance + C2); return numerator / denominator; } - /** Returns the means of the pixels in the two provided windows, in order. */ + /** Returns the mean of the pixels in the window. 
*/ private static double getMean( - byte[] pixels, int start, int stride, int windowWidth, int windowHeight) { + byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) { double total = 0; for (int y = 0; y < windowHeight; y++) { for (int x = 0; x < windowWidth; x++) { - total += pixels[getGlobalCoordinate(x, y, stride, start)]; + total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF; } } - return total / windowWidth * windowHeight; + return total / (windowWidth * windowHeight); } - /** Returns the two variances and the covariance of the two windows. */ + /** Calculates the variances and covariance of the pixels in the window for both buffers. */ private static double[] getVariancesAndCovariance( - byte[] pixelsX, - byte[] pixelsY, - double meanX, - double meanY, - int start, + byte[] referenceBuffer, + byte[] distortedBuffer, + double referenceMean, + double distortedMean, + int bufferIndexOffset, int stride, int windowWidth, int windowHeight) { - // The variances in X and Y. - double varX = 0; - double varY = 0; - // The covariance between X and Y. 
- double covXY = 0; + double referenceVariance = 0; + double distortedVariance = 0; + double referenceDistortedCovariance = 0; for (int y = 0; y < windowHeight; y++) { for (int x = 0; x < windowWidth; x++) { - int index = getGlobalCoordinate(x, y, stride, start); - double offsetX = pixelsX[index] - meanX; - double offsetY = pixelsY[index] - meanY; - varX += pow(offsetX, 2); - varY += pow(offsetY, 2); - covXY += offsetX * offsetY; + int index = get1dIndex(x, y, stride, bufferIndexOffset); + double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean; + double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean; + referenceVariance += referencePixelDeviation * referencePixelDeviation; + distortedVariance += distortedPixelDeviation * distortedPixelDeviation; + referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation; } } int normalizationFactor = windowWidth * windowHeight - 1; + return new double[] { - varX / normalizationFactor, varY / normalizationFactor, covXY / normalizationFactor + referenceVariance / normalizationFactor, + distortedVariance / normalizationFactor, + referenceDistortedCovariance / normalizationFactor }; } - private static int getGlobalCoordinate(int x, int y, int stride, int offset) { + /** + * Translates a 2D coordinate into a 1D index, based on the stride of the 2D space. + * + * @param x The width component of coordinate. + * @param y The height component of coordinate. + * @param stride The width of the 2D space. + * @param offset An offset to apply. + * @return The 1D index. 
+ */ + private static int get1dIndex(int x, int y, int stride, int offset) { return x + (y * stride) + offset; } } diff --git a/libraries/transformer/src/androidTest/java/androidx/media3/transformer/TransformerAndroidTestRunner.java b/libraries/transformer/src/androidTest/java/androidx/media3/transformer/TransformerAndroidTestRunner.java index 147c4fc9c9..1742c4edc9 100644 --- a/libraries/transformer/src/androidTest/java/androidx/media3/transformer/TransformerAndroidTestRunner.java +++ b/libraries/transformer/src/androidTest/java/androidx/media3/transformer/TransformerAndroidTestRunner.java @@ -333,7 +333,7 @@ public class TransformerAndroidTestRunner { SsimHelper.calculate( context, /* referenceVideoPath= */ checkNotNull(mediaItem.localConfiguration).uri.toString(), - outputVideoFile.getPath()); + /* distortedVideoPath= */ outputVideoFile.getPath()); resultBuilder.setSsim(ssim); } catch (InterruptedException interruptedException) { // InterruptedException is a special unexpected case because it is not related to Ssim