Preallocate buffer and use byte for Luminance in SSIM.

PiperOrigin-RevId: 439855702
This commit is contained in:
claincly 2022-04-06 17:08:40 +01:00 committed by Ian Baker
parent 7bd650b315
commit b5eba24e1f

View File

@ -50,6 +50,9 @@ import java.nio.ByteBuffer;
* images like MSE (mean squared error), but rather outputs the human perceptual difference. A * images like MSE (mean squared error), but rather outputs the human perceptual difference. A
* higher SSIM score signifies higher similarity, while a SSIM score of 1 means the two images are * higher SSIM score signifies higher similarity, while a SSIM score of 1 means the two images are
* exactly the same. * exactly the same.
*
* <p>SSIM is traditionally computed with the luminance channel (Y), this class uses the luma
* channel (Y') because the {@linkplain MediaCodec decoder} decodes to luma.
*/ */
public final class SsimHelper { public final class SsimHelper {
@ -76,6 +79,8 @@ public final class SsimHelper {
new VideoDecodingWrapper(context, expectedVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, expectedVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper actualDecodingWrapper = VideoDecodingWrapper actualDecodingWrapper =
new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, actualVideoPath, DEFAULT_COMPARISON_INTERVAL);
@Nullable byte[] expectedLumaBuffer = null;
@Nullable byte[] actualLumaBuffer = null;
double accumulatedSsim = 0.0; double accumulatedSsim = 0.0;
int comparedImagesCount = 0; int comparedImagesCount = 0;
try { try {
@ -90,13 +95,21 @@ public final class SsimHelper {
int width = expectedImage.getWidth(); int width = expectedImage.getWidth();
int height = expectedImage.getHeight(); int height = expectedImage.getHeight();
assertThat(actualImage.getWidth()).isEqualTo(width); assertThat(actualImage.getWidth()).isEqualTo(width);
assertThat(actualImage.getHeight()).isEqualTo(height); assertThat(actualImage.getHeight()).isEqualTo(height);
if (expectedLumaBuffer == null || expectedLumaBuffer.length != width * height) {
expectedLumaBuffer = new byte[width * height];
}
if (actualLumaBuffer == null || actualLumaBuffer.length != width * height) {
actualLumaBuffer = new byte[width * height];
}
try { try {
accumulatedSsim += accumulatedSsim +=
SsimCalculator.calculate( SsimCalculator.calculate(
extractLumaChannelBuffer(expectedImage), extractLumaChannelBuffer(expectedImage, expectedLumaBuffer),
extractLumaChannelBuffer(actualImage), extractLumaChannelBuffer(actualImage, actualLumaBuffer),
/* offset= */ 0, /* offset= */ 0,
/* stride= */ width, /* stride= */ width,
width, width,
@ -116,11 +129,13 @@ public final class SsimHelper {
} }
/** /**
* Returns the buffer of the luma (Y) channel of the image. * Extracts, sets and returns the buffer of the luma (Y') channel of the image.
* *
* @param image The {@link Image} in YUV format. * @param image The {@link Image} in YUV format.
* @param lumaChannelBuffer The buffer where the extracted luma values are stored.
* @return The {@code lumaChannelBuffer} for convenience.
*/ */
private static int[] extractLumaChannelBuffer(Image image) { private static byte[] extractLumaChannelBuffer(Image image, byte[] lumaChannelBuffer) {
// This method is invoked on the main thread. // This method is invoked on the main thread.
// `image` should contain YUV channels. // `image` should contain YUV channels.
Image.Plane[] imagePlanes = image.getPlanes(); Image.Plane[] imagePlanes = image.getPlanes();
@ -131,7 +146,6 @@ public final class SsimHelper {
int width = image.getWidth(); int width = image.getWidth();
int height = image.getHeight(); int height = image.getHeight();
ByteBuffer lumaByteBuffer = lumaPlane.getBuffer(); ByteBuffer lumaByteBuffer = lumaPlane.getBuffer();
int[] lumaChannelBuffer = new int[width * height];
for (int y = 0; y < height; y++) { for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) { for (int x = 0; x < width; x++) {
lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride); lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride);
@ -148,7 +162,7 @@ public final class SsimHelper {
// Use ExoPlayer's 10ms timeout setting. In practise, the test durations from using timeouts of // Use ExoPlayer's 10ms timeout setting. In practise, the test durations from using timeouts of
// 1/10/100ms don't differ significantly. // 1/10/100ms don't differ significantly.
private static final long DEQUEUE_TIMEOUT_US = 10_000; private static final long DEQUEUE_TIMEOUT_US = 10_000;
// SSIM should be calculated using the luma (Y) channel, thus using the YUV color space. // SSIM should be calculated using the luma (Y') channel, thus using the YUV color space.
private static final int IMAGE_READER_COLOR_SPACE = ImageFormat.YUV_420_888; private static final int IMAGE_READER_COLOR_SPACE = ImageFormat.YUV_420_888;
private static final int MEDIA_CODEC_COLOR_SPACE = private static final int MEDIA_CODEC_COLOR_SPACE =
MediaCodecInfo.CodecCapabilities.COLOR_FormatYUV420Flexible; MediaCodecInfo.CodecCapabilities.COLOR_FormatYUV420Flexible;
@ -333,8 +347,8 @@ public final class SsimHelper {
/** /**
* Calculates the Structural Similarity Index (SSIM) between two images. * Calculates the Structural Similarity Index (SSIM) between two images.
* *
* @param expected The luminance channel (Y) bitmap of the expected image. * @param expected The luma channel (Y) bitmap of the expected image.
* @param actual The luminance channel (Y) bitmap of the actual image. * @param actual The luma channel (Y) bitmap of the actual image.
* @param offset The offset. * @param offset The offset.
* @param stride The stride of the bitmap. * @param stride The stride of the bitmap.
* @param width The image width in pixels. * @param width The image width in pixels.
@ -342,7 +356,7 @@ public final class SsimHelper {
* @return The SSIM score between the input images. * @return The SSIM score between the input images.
*/ */
public static double calculate( public static double calculate(
int[] expected, int[] actual, int offset, int stride, int width, int height) { byte[] expected, byte[] actual, int offset, int stride, int width, int height) {
double totalSsim = 0; double totalSsim = 0;
int windowsCount = 0; int windowsCount = 0;
@ -398,7 +412,7 @@ public final class SsimHelper {
/** Returns the means of the pixels in the two provided windows, in order. */ /** Returns the means of the pixels in the two provided windows, in order. */
private static double getMean( private static double getMean(
int[] pixels, int start, int stride, int windowWidth, int windowHeight) { byte[] pixels, int start, int stride, int windowWidth, int windowHeight) {
double total = 0; double total = 0;
for (int y = 0; y < windowHeight; y++) { for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) { for (int x = 0; x < windowWidth; x++) {
@ -410,8 +424,8 @@ public final class SsimHelper {
/** Returns the two variances and the covariance of the two windows. */ /** Returns the two variances and the covariance of the two windows. */
private static double[] getVariancesAndCovariance( private static double[] getVariancesAndCovariance(
int[] pixelsX, byte[] pixelsX,
int[] pixelsY, byte[] pixelsY,
double meanX, double meanX,
double meanY, double meanY,
int start, int start,