`decode_gif` has never handled animated GIFs properly (ever since it was first written). Apply GIF transparency data while decoding so that all frames in an animated GIF are correctly decoded.

PiperOrigin-RevId: 349259384
Change-Id: I147390cdedf2a4c701acfe832f715355584c592b
This commit is contained in:
Hye Soo Yang 2020-12-28 02:49:10 -08:00 committed by TensorFlower Gardener
parent ed66f33d00
commit 3741a7bda7
9 changed files with 260 additions and 13 deletions

View File

@ -5,6 +5,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"if_mobile",
"tf_cc_test",
"tf_copts",
)
load(
@ -64,3 +65,18 @@ cc_library(
"@com_google_absl//absl/strings",
],
)
# Unit tests for the GIF decoder (gif_io), including frame-by-frame
# comparison of animated GIFs against lossless PNG ground truth.
tf_cc_test(
    name = "lib_gif_io_test",
    srcs = ["gif_io_test.cc"],
    # GIF/PNG test images consumed by the decode tests.
    data = ["//tensorflow/core/lib/gif/testdata:gif_testdata"],
    deps = [
        ":gif_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/png:png_io",
        "@com_google_absl//absl/base",
    ],
)

View File

@ -111,11 +111,31 @@ uint8* Decode(const void* srcdata, int datasize,
SavedImage* this_image = &gif_file->SavedImages[k];
GifImageDesc* img_desc = &this_image->ImageDesc;
// The Graphics Control Block tells us which index in the color map
// correspond to "transparent color", i.e. no need to update the pixel
// on the canvas. The "transparent color index" is specific to each
// sub-frame.
GraphicsControlBlock gcb;
DGifSavedExtensionToGCB(gif_file, k, &gcb);
int imgLeft = img_desc->Left;
int imgTop = img_desc->Top;
int imgRight = img_desc->Left + img_desc->Width;
int imgBottom = img_desc->Top + img_desc->Height;
if (k > 0) {
uint8* last_dst = dstdata + (k - 1) * width * channel * height;
for (int i = 0; i < height; ++i) {
uint8* p_dst = this_dst + i * width * channel;
uint8* l_dst = last_dst + i * width * channel;
for (int j = 0; j < width; ++j) {
p_dst[j * channel + 0] = l_dst[j * channel + 0];
p_dst[j * channel + 1] = l_dst[j * channel + 1];
p_dst[j * channel + 2] = l_dst[j * channel + 2];
}
}
}
if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width ||
img_desc->Height != height) {
// If the first frame does not fill the entire canvas then fill the
@ -129,19 +149,6 @@ uint8* Decode(const void* srcdata, int datasize,
p_dst[j * channel + 2] = 0;
}
}
} else {
// Otherwise previous frame will be reused to fill the unoccupied
// canvas.
uint8* last_dst = dstdata + (k - 1) * width * channel * height;
for (int i = 0; i < height; ++i) {
uint8* p_dst = this_dst + i * width * channel;
uint8* l_dst = last_dst + i * width * channel;
for (int j = 0; j < width; ++j) {
p_dst[j * channel + 0] = l_dst[j * channel + 0];
p_dst[j * channel + 1] = l_dst[j * channel + 1];
p_dst[j * channel + 2] = l_dst[j * channel + 2];
}
}
}
imgLeft = std::max(imgLeft, 0);
@ -172,6 +179,12 @@ uint8* Decode(const void* srcdata, int datasize,
return nullptr;
}
if (color_index == gcb.TransparentColor) {
// Use the pixel from the previous frame. In other words, no need to
// update our canvas for this pixel.
continue;
}
const GifColorType& gif_color = color_map->Colors[color_index];
p_dst[j * channel + 0] = gif_color.Red;
p_dst[j * channel + 1] = gif_color.Green;

View File

@ -0,0 +1,192 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/gif/gif_io.h"
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gif {
namespace {
const char kTestData[] = "tensorflow/core/lib/gif/testdata/";
struct DecodeGifTestCase {
const string filepath;
const int num_frames;
const int width;
const int height;
const int channels;
};
// Reads the file at `filename` into `*output`, CHECK-failing the test
// process on any I/O error.
void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
  TF_CHECK_OK(ReadFileToString(env, filename, output));
}
// Decodes the gif image in `testcase.filepath` and verifies that the
// frame count and dimensions reported by the decoder match the
// ground-truth values recorded in `testcase`.
void TestDecodeGif(Env* env, DecodeGifTestCase testcase) {
  string encoded_gif;
  ReadFileToStringOrDie(env, testcase.filepath, &encoded_gif);

  // Decode gif image data; the allocator callback captures the metadata
  // the decoder reports.
  int frame_count = 0, decoded_width = 0, decoded_height = 0,
      decoded_channels = 0;
  string error_string;
  std::unique_ptr<uint8[]> decoded(gif::Decode(
      encoded_gif.data(), encoded_gif.size(),
      [&](int num_frames, int width, int height, int channels) -> uint8* {
        frame_count = num_frames;
        decoded_width = width;
        decoded_height = height;
        decoded_channels = channels;
        return new uint8[num_frames * height * width * channels];
      },
      &error_string));
  ASSERT_NE(decoded, nullptr);

  // Make sure the decoded information matches the ground-truth image info.
  ASSERT_EQ(frame_count, testcase.num_frames);
  ASSERT_EQ(decoded_width, testcase.width);
  ASSERT_EQ(decoded_height, testcase.height);
  ASSERT_EQ(decoded_channels, testcase.channels);
}
// Checks decoded metadata for a set of static and animated test gifs.
TEST(GifTest, Gif) {
  Env* env = Env::Default();
  const string testdata_path = kTestData;
  std::vector<DecodeGifTestCase> testcases(
      {// filepath, num_frames, width, height, channels
       {testdata_path + "lena.gif", 1, 51, 26, 3},
       {testdata_path + "optimized.gif", 12, 20, 40, 3},
       {testdata_path + "red_black.gif", 1, 16, 16, 3},
       {testdata_path + "scan.gif", 12, 20, 40, 3},
       {testdata_path + "squares.gif", 2, 16, 16, 3}});

  for (const auto& tc : testcases) {
    TestDecodeGif(env, tc);
  }
}
// Compares frame `frame_idx` of the already-decoded gif pixel data in
// `gif_data` against the lossless ground-truth png at `png_filepath`,
// byte-for-byte. `gif_data` stores frames contiguously, each of size
// width * height * channels.
void TestDecodeAnimatedGif(Env* env, const uint8* gif_data,
                           const string& png_filepath, int frame_idx) {
  string png;  // ground-truth
  ReadFileToStringOrDie(env, png_filepath, &png);

  // Decode the ground-truth png into raw 8-bit RGB pixels.
  png::DecodeContext decode;
  png::CommonInitDecode(png, 3, 8, &decode);
  const int width = static_cast<int>(decode.width);
  const int height = static_cast<int>(decode.height);
  std::unique_ptr<uint8[]> png_imgdata(
      new uint8[height * width * decode.channels]);
  png::CommonFinishDecode(reinterpret_cast<png_bytep>(png_imgdata.get()),
                          decode.channels * width * sizeof(uint8), &decode);

  // Locate the requested gif frame and compare it to the png pixels.
  int frame_len = width * height * decode.channels;
  int gif_idx = frame_len * frame_idx;
  for (int i = 0; i < frame_len; i++) {
    ASSERT_EQ(gif_data[gif_idx + i], png_imgdata[i]);
  }
}
// Decodes an animated gif once and verifies each of its three frames
// against per-frame png ground truth.
TEST(GifTest, AnimatedGif) {
  Env* env = Env::Default();
  const string testdata_path = kTestData;

  // Read animated gif file once.
  string gif;
  ReadFileToStringOrDie(env, testdata_path + "pendulum_sm.gif", &gif);

  std::unique_ptr<uint8[]> gif_imgdata;
  int nframes, w, h, c;
  string error_string;
  gif_imgdata.reset(gif::Decode(
      gif.data(), gif.size(),
      [&](int num_frames, int width, int height, int channels) -> uint8* {
        nframes = num_frames;
        w = width;
        h = height;
        c = channels;
        return new uint8[num_frames * height * width * channels];
      },
      &error_string));
  // Fail fast if decoding failed; TestDecodeAnimatedGif dereferences the
  // buffer, so passing a null pointer down would be undefined behavior.
  ASSERT_NE(gif_imgdata, nullptr) << error_string;

  // Compare each decoded gif frame with its ground-truth png frame.
  TestDecodeAnimatedGif(env, gif_imgdata.get(),
                        testdata_path + "pendulum_sm_frame0.png", 0);
  TestDecodeAnimatedGif(env, gif_imgdata.get(),
                        testdata_path + "pendulum_sm_frame1.png", 1);
  TestDecodeAnimatedGif(env, gif_imgdata.get(),
                        testdata_path + "pendulum_sm_frame2.png", 2);
}
// Decodes `filepath` with `expand_animations == false` and verifies that
// only the first frame is produced.
void TestExpandAnimations(Env* env, const string& filepath) {
  string gif;
  ReadFileToStringOrDie(env, filepath, &gif);

  std::unique_ptr<uint8[]> imgdata;
  string error_string;
  // Initialize so a failed decode (callback never invoked) cannot leave
  // `nframes` uninitialized.
  int nframes = 0;
  // `expand_animations` is set to true by default. Set to false.
  bool expand_animations = false;
  imgdata.reset(gif::Decode(
      gif.data(), gif.size(),
      [&](int frame_cnt, int width, int height, int channels) -> uint8* {
        nframes = frame_cnt;
        return new uint8[frame_cnt * height * width * channels];
      },
      &error_string, expand_animations));
  // The frame count is only meaningful if decoding succeeded.
  ASSERT_NE(imgdata, nullptr) << error_string;

  // Check that only 1 frame is being decoded.
  ASSERT_EQ(nframes, 1);
}
// Runs the expand_animations=false check over every animated test gif.
TEST(GifTest, ExpandAnimations) {
  Env* env = Env::Default();
  const string testdata_path = kTestData;

  // Test all animated gif test images.
  for (const char* filename : {"scan.gif", "pendulum_sm.gif", "squares.gif"}) {
    TestExpandAnimations(env, testdata_path + filename);
  }
}
// Feeds `header_bytes` (magic bytes of a non-gif image format) to the gif
// decoder and verifies that decoding fails with the expected error message.
void TestInvalidGifFormat(const string& header_bytes) {
  std::unique_ptr<uint8[]> imgdata;
  string error_string;
  imgdata.reset(gif::Decode(
      header_bytes.data(), header_bytes.size(),
      [&](int frame_cnt, int width, int height, int channels) -> uint8* {
        return new uint8[frame_cnt * height * width * channels];
      },
      &error_string));
  // Decoding image formats other than gif must fail and return no buffer.
  ASSERT_EQ(imgdata, nullptr);
  const string err_msg = "failed to open gif file";
  // Compare a prefix of exactly `err_msg.size()` bytes instead of a
  // hard-coded length that can silently drift out of sync with the message.
  ASSERT_EQ(error_string.substr(0, err_msg.size()), err_msg);
}
// Feeds header bytes of other image formats to the gif decoder; every one
// of them must be rejected.
TEST(GifTest, BadGif) {
  const char* kBadHeaders[] = {
      "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A",  // png
      "\x42\x4d",                          // bmp
      "\xff\xd8\xff",                      // jpeg
      "\x49\x49\x2A\x00",                  // tiff
  };
  for (const char* header : kBadHeaders) {
    TestInvalidGifFormat(header);
  }
}
} // namespace
} // namespace gif
} // namespace tensorflow

View File

@ -15,6 +15,12 @@ filegroup(
"scan.gif",
"red_black.gif",
"squares.gif",
"pendulum_sm.gif",
# Add groundtruth frames for `pendulum_sm.gif`.
# PNG format because it's lossless.
"pendulum_sm_frame0.png",
"pendulum_sm_frame1.png",
"pendulum_sm_frame2.png",
# GIF data with optimization
"optimized.gif",
],

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

View File

@ -4533,6 +4533,26 @@ class GifTest(test_util.TensorFlowTestCase):
image = image_ops.decode_gif(gif)
self.assertEqual(image.get_shape().as_list(), [None, None, None, 3])
def testAnimatedGif(self):
  # Verify that every frame of an animated GIF file is decoded correctly
  # by comparing against lossless per-frame PNG ground truth.
  with self.cached_session(use_gpu=True):
    base = "tensorflow/core/lib/gif/testdata"
    gif = io_ops.read_file(os.path.join(base, "pendulum_sm.gif"))
    decoded_gif = image_ops.decode_gif(gif)

    gt_names = ["pendulum_sm_frame0.png", "pendulum_sm_frame1.png",
                "pendulum_sm_frame2.png"]
    gt_frames = [
        image_ops.decode_png(io_ops.read_file(os.path.join(base, name)))
        for name in gt_names
    ]

    results = self.evaluate([decoded_gif] + gt_frames)
    image, frames = results[0], results[1:]
    # Compare decoded gif frames with ground-truth data.
    for idx, frame in enumerate(frames):
      self.assertAllEqual(image[idx], frame)
class ConvertImageTest(test_util.TensorFlowTestCase):