Handle int instead of byte in SSIM.

The value of pixels are converted to integers at the point of use,
move this logic to the initialisation step.

This is a prerequisite step for testing SSIM calculation, which
will lead on to some SSIM improvements being verifiable.

Tested manually and SSIM values match for the same video
before and after this change.

PiperOrigin-RevId: 473231779
This commit is contained in:
samrobinson 2022-09-09 13:07:14 +00:00 committed by Marc Baechinger
parent 4133bb6070
commit 3d5ddf0c42

View File

@ -80,8 +80,8 @@ public final class SsimHelper {
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper distortedDecodingWrapper = VideoDecodingWrapper distortedDecodingWrapper =
new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL); new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
@Nullable byte[] referenceLumaBuffer = null; @Nullable int[] referenceLumaBuffer = null;
@Nullable byte[] distortedLumaBuffer = null; @Nullable int[] distortedLumaBuffer = null;
double accumulatedSsim = 0.0; double accumulatedSsim = 0.0;
int comparedImagesCount = 0; int comparedImagesCount = 0;
try { try {
@ -101,10 +101,10 @@ public final class SsimHelper {
assertThat(distortedImage.getHeight()).isEqualTo(height); assertThat(distortedImage.getHeight()).isEqualTo(height);
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) { if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
referenceLumaBuffer = new byte[width * height]; referenceLumaBuffer = new int[width * height];
} }
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) { if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
distortedLumaBuffer = new byte[width * height]; distortedLumaBuffer = new int[width * height];
} }
try { try {
accumulatedSsim += accumulatedSsim +=
@ -134,7 +134,7 @@ public final class SsimHelper {
* @param lumaChannelBuffer The buffer where the extracted luma values are stored. * @param lumaChannelBuffer The buffer where the extracted luma values are stored.
* @return The {@code lumaChannelBuffer} for convenience. * @return The {@code lumaChannelBuffer} for convenience.
*/ */
private static byte[] extractLumaChannelBuffer(Image image, byte[] lumaChannelBuffer) { private static int[] extractLumaChannelBuffer(Image image, int[] lumaChannelBuffer) {
// This method is invoked on the main thread. // This method is invoked on the main thread.
// `image` should contain YUV channels. // `image` should contain YUV channels.
Image.Plane[] imagePlanes = image.getPlanes(); Image.Plane[] imagePlanes = image.getPlanes();
@ -147,7 +147,8 @@ public final class SsimHelper {
ByteBuffer lumaByteBuffer = lumaPlane.getBuffer(); ByteBuffer lumaByteBuffer = lumaPlane.getBuffer();
for (int y = 0; y < height; y++) { for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) { for (int x = 0; x < width; x++) {
lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride); lumaChannelBuffer[y * width + x] =
lumaByteBuffer.get(y * rowStride + x * pixelStride) & 0xFF;
} }
} }
return lumaChannelBuffer; return lumaChannelBuffer;
@ -363,7 +364,7 @@ public final class SsimHelper {
* @return The MSSIM score between the input images. * @return The MSSIM score between the input images.
*/ */
public static double calculate( public static double calculate(
byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) { int[] referenceBuffer, int[] distortedBuffer, int width, int height) {
double totalSsim = 0; double totalSsim = 0;
int windowsCount = 0; int windowsCount = 0;
@ -450,11 +451,11 @@ public final class SsimHelper {
/** Returns the mean of the pixels in the window. */ /** Returns the mean of the pixels in the window. */
private static double getMean( private static double getMean(
byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) { int[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
double total = 0; double total = 0;
for (int y = 0; y < windowHeight; y++) { for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) { for (int x = 0; x < windowWidth; x++) {
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF; total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
} }
} }
return total / (windowWidth * windowHeight); return total / (windowWidth * windowHeight);
@ -462,8 +463,8 @@ public final class SsimHelper {
/** Calculates the variances and covariance of the pixels in the window for both buffers. */ /** Calculates the variances and covariance of the pixels in the window for both buffers. */
private static double[] getVariancesAndCovariance( private static double[] getVariancesAndCovariance(
byte[] referenceBuffer, int[] referenceBuffer,
byte[] distortedBuffer, int[] distortedBuffer,
double referenceMean, double referenceMean,
double distortedMean, double distortedMean,
int bufferIndexOffset, int bufferIndexOffset,
@ -476,8 +477,8 @@ public final class SsimHelper {
for (int y = 0; y < windowHeight; y++) { for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) { for (int x = 0; x < windowWidth; x++) {
int index = get1dIndex(x, y, stride, bufferIndexOffset); int index = get1dIndex(x, y, stride, bufferIndexOffset);
double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean; double referencePixelDeviation = referenceBuffer[index] - referenceMean;
double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean; double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
referenceVariance += referencePixelDeviation * referencePixelDeviation; referenceVariance += referencePixelDeviation * referencePixelDeviation;
distortedVariance += distortedPixelDeviation * distortedPixelDeviation; distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation; referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;