diff --git a/tensorflow/core/lib/gif/BUILD b/tensorflow/core/lib/gif/BUILD index 49ada18e31f..7a52fcb2411 100644 --- a/tensorflow/core/lib/gif/BUILD +++ b/tensorflow/core/lib/gif/BUILD @@ -5,6 +5,7 @@ load( "//tensorflow:tensorflow.bzl", "if_android", "if_mobile", + "tf_cc_test", "tf_copts", ) load( @@ -64,3 +65,18 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +tf_cc_test( + name = "lib_gif_io_test", + srcs = ["gif_io_test.cc"], + data = ["//tensorflow/core/lib/gif/testdata:gif_testdata"], + deps = [ + ":gif_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/png:png_io", + "@com_google_absl//absl/base", + ], +) diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index 5fb47043654..ba4aa1156db 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -111,11 +111,31 @@ uint8* Decode(const void* srcdata, int datasize, SavedImage* this_image = &gif_file->SavedImages[k]; GifImageDesc* img_desc = &this_image->ImageDesc; + // The Graphics Control Block tells us which index in the color map + // correspond to "transparent color", i.e. no need to update the pixel + // on the canvas. The "transparent color index" is specific to each + // sub-frame. + GraphicsControlBlock gcb; + DGifSavedExtensionToGCB(gif_file, k, &gcb); + int imgLeft = img_desc->Left; int imgTop = img_desc->Top; int imgRight = img_desc->Left + img_desc->Width; int imgBottom = img_desc->Top + img_desc->Height; + if (k > 0) { + uint8* last_dst = dstdata + (k - 1) * width * channel * height; + for (int i = 0; i < height; ++i) { + uint8* p_dst = this_dst + i * width * channel; + uint8* l_dst = last_dst + i * width * channel; + for (int j = 0; j < width; ++j) { + p_dst[j * channel + 0] = l_dst[j * channel + 0]; + p_dst[j * channel + 1] = l_dst[j * channel + 1]; + p_dst[j * channel + 2] = l_dst[j * channel + 2]; + } + } + } + if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width || img_desc->Height != height) { // If the first frame does not fill the entire canvas then fill the @@ -129,19 +149,6 @@ uint8* Decode(const void* srcdata, int datasize, p_dst[j * channel + 2] = 0; } } - } else { - // Otherwise previous frame will be reused to fill the unoccupied - // canvas. - uint8* last_dst = dstdata + (k - 1) * width * channel * height; - for (int i = 0; i < height; ++i) { - uint8* p_dst = this_dst + i * width * channel; - uint8* l_dst = last_dst + i * width * channel; - for (int j = 0; j < width; ++j) { - p_dst[j * channel + 0] = l_dst[j * channel + 0]; - p_dst[j * channel + 1] = l_dst[j * channel + 1]; - p_dst[j * channel + 2] = l_dst[j * channel + 2]; - } - } } imgLeft = std::max(imgLeft, 0); @@ -172,6 +179,12 @@ uint8* Decode(const void* srcdata, int datasize, return nullptr; } + if (color_index == gcb.TransparentColor) { + // Use the pixel from the previous frame. In other words, no need to + // update our canvas for this pixel. + continue; + } + const GifColorType& gif_color = color_map->Colors[color_index]; p_dst[j * channel + 0] = gif_color.Red; p_dst[j * channel + 1] = gif_color.Green; diff --git a/tensorflow/core/lib/gif/gif_io_test.cc b/tensorflow/core/lib/gif/gif_io_test.cc new file mode 100644 index 00000000000..38c18191169 --- /dev/null +++ b/tensorflow/core/lib/gif/gif_io_test.cc @@ -0,0 +1,192 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/gif/gif_io.h" + +#include "tensorflow/core/lib/png/png_io.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gif { +namespace { + +const char kTestData[] = "tensorflow/core/lib/gif/testdata/"; + +struct DecodeGifTestCase { + const string filepath; + const int num_frames; + const int width; + const int height; + const int channels; +}; + +void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { + TF_CHECK_OK(ReadFileToString(env, filename, output)); +} + +void TestDecodeGif(Env* env, DecodeGifTestCase testcase) { + string gif; + ReadFileToStringOrDie(env, testcase.filepath, &gif); + + // Decode gif image data. + std::unique_ptr imgdata; + int nframes, w, h, c; + string error_string; + imgdata.reset(gif::Decode( + gif.data(), gif.size(), + [&](int frame_cnt, int width, int height, int channels) -> uint8* { + nframes = frame_cnt; + w = width; + h = height; + c = channels; + return new uint8[frame_cnt * height * width * channels]; + }, + &error_string)); + ASSERT_NE(imgdata, nullptr); + // Make sure the decoded information matches the ground-truth image info. + ASSERT_EQ(nframes, testcase.num_frames); + ASSERT_EQ(w, testcase.width); + ASSERT_EQ(h, testcase.height); + ASSERT_EQ(c, testcase.channels); +} + +TEST(GifTest, Gif) { + Env* env = Env::Default(); + const string testdata_path = kTestData; + std::vector testcases( + {// file_path, num_of_channels, width, height, channels + {testdata_path + "lena.gif", 1, 51, 26, 3}, + {testdata_path + "optimized.gif", 12, 20, 40, 3}, + {testdata_path + "red_black.gif", 1, 16, 16, 3}, + {testdata_path + "scan.gif", 12, 20, 40, 3}, + {testdata_path + "squares.gif", 2, 16, 16, 3}}); + + for (const auto& tc : testcases) { + TestDecodeGif(env, tc); + } +} + +void TestDecodeAnimatedGif(Env* env, const uint8* gif_data, + const string& png_filepath, int frame_idx) { + string png; // ground-truth + ReadFileToStringOrDie(env, png_filepath, &png); + + // Compare decoded gif to ground-truth image frames in png format. + png::DecodeContext decode; + png::CommonInitDecode(png, 3, 8, &decode); + const int width = static_cast(decode.width); + const int height = static_cast(decode.height); + std::unique_ptr png_imgdata( + new uint8[height * width * decode.channels]); + png::CommonFinishDecode(reinterpret_cast(png_imgdata.get()), + decode.channels * width * sizeof(uint8), &decode); + + int frame_len = width * height * decode.channels; + int gif_idx = frame_len * frame_idx; + for (int i = 0; i < frame_len; i++) { + ASSERT_EQ(gif_data[gif_idx + i], png_imgdata[i]); + } +} + +TEST(GifTest, AnimatedGif) { + Env* env = Env::Default(); + const string testdata_path = kTestData; + + // Read animated gif file once. + string gif; + ReadFileToStringOrDie(env, testdata_path + "pendulum_sm.gif", &gif); + + std::unique_ptr gif_imgdata; + int nframes, w, h, c; + string error_string; + gif_imgdata.reset(gif::Decode( + gif.data(), gif.size(), + [&](int num_frames, int width, int height, int channels) -> uint8* { + nframes = num_frames; + w = width; + h = height; + c = channels; + return new uint8[num_frames * height * width * channels]; + }, + &error_string)); + + TestDecodeAnimatedGif(env, gif_imgdata.get(), + testdata_path + "pendulum_sm_frame0.png", 0); + TestDecodeAnimatedGif(env, gif_imgdata.get(), + testdata_path + "pendulum_sm_frame1.png", 1); + TestDecodeAnimatedGif(env, gif_imgdata.get(), + testdata_path + "pendulum_sm_frame2.png", 2); +} + +void TestExpandAnimations(Env* env, const string& filepath) { + string gif; + ReadFileToStringOrDie(env, filepath, &gif); + + std::unique_ptr imgdata; + string error_string; + int nframes; + // `expand_animations` is set to true by default. Set to false. + bool expand_animations = false; + imgdata.reset(gif::Decode( + gif.data(), gif.size(), + [&](int frame_cnt, int width, int height, int channels) -> uint8* { + nframes = frame_cnt; + return new uint8[frame_cnt * height * width * channels]; + }, + &error_string, expand_animations)); + + // Check that only 1 frame is being decoded. + ASSERT_EQ(nframes, 1); +} + +TEST(GifTest, ExpandAnimations) { + Env* env = Env::Default(); + const string testdata_path = kTestData; + + // Test all animated gif test images. + TestExpandAnimations(env, testdata_path + "scan.gif"); + TestExpandAnimations(env, testdata_path + "pendulum_sm.gif"); + TestExpandAnimations(env, testdata_path + "squares.gif"); +} + +void TestInvalidGifFormat(const string& header_bytes) { + std::unique_ptr imgdata; + string error_string; + int nframes; + imgdata.reset(gif::Decode( + header_bytes.data(), header_bytes.size(), + [&](int frame_cnt, int width, int height, int channels) -> uint8* { + nframes = frame_cnt; + return new uint8[frame_cnt * height * width * channels]; + }, + &error_string)); + + // Check that decoding image formats other than gif throws an error. + string err_msg = "failed to open gif file"; + ASSERT_EQ(error_string.substr(0, 23), err_msg); +} + +TEST(GifTest, BadGif) { + // Input header bytes of other image formats to gif decoder. + TestInvalidGifFormat("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"); // png + TestInvalidGifFormat("\x42\x4d"); // bmp + TestInvalidGifFormat("\xff\xd8\xff"); // jpeg + TestInvalidGifFormat("\x49\x49\x2A\x00"); // tiff +} + +} // namespace +} // namespace gif +} // namespace tensorflow diff --git a/tensorflow/core/lib/gif/testdata/BUILD b/tensorflow/core/lib/gif/testdata/BUILD index ff7d9f7a58c..b7169510c9d 100644 --- a/tensorflow/core/lib/gif/testdata/BUILD +++ b/tensorflow/core/lib/gif/testdata/BUILD @@ -15,6 +15,12 @@ filegroup( "scan.gif", "red_black.gif", "squares.gif", + "pendulum_sm.gif", + # Add groundtruth frames for `pendulum_sm.gif`. + # PNG format because it's lossless. + "pendulum_sm_frame0.png", + "pendulum_sm_frame1.png", + "pendulum_sm_frame2.png", # GIF data with optimization "optimized.gif", ], diff --git a/tensorflow/core/lib/gif/testdata/pendulum_sm.gif b/tensorflow/core/lib/gif/testdata/pendulum_sm.gif new file mode 100644 index 00000000000..4edb67dc528 Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/pendulum_sm.gif differ diff --git a/tensorflow/core/lib/gif/testdata/pendulum_sm_frame0.png b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame0.png new file mode 100644 index 00000000000..43a0a15f253 Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame0.png differ diff --git a/tensorflow/core/lib/gif/testdata/pendulum_sm_frame1.png b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame1.png new file mode 100644 index 00000000000..47c42dc2db6 Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame1.png differ diff --git a/tensorflow/core/lib/gif/testdata/pendulum_sm_frame2.png b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame2.png new file mode 100644 index 00000000000..7f7607c2d17 Binary files /dev/null and b/tensorflow/core/lib/gif/testdata/pendulum_sm_frame2.png differ diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 0e871949dba..653272e7015 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -4533,6 +4533,26 @@ class GifTest(test_util.TensorFlowTestCase): image = image_ops.decode_gif(gif) self.assertEqual(image.get_shape().as_list(), [None, None, None, 3]) + def testAnimatedGif(self): + # Test if all frames in the animated GIF file is properly decoded. + with self.cached_session(use_gpu=True): + base = "tensorflow/core/lib/gif/testdata" + gif = io_ops.read_file(os.path.join(base, "pendulum_sm.gif")) + gt_frame0 = io_ops.read_file(os.path.join(base, "pendulum_sm_frame0.png")) + gt_frame1 = io_ops.read_file(os.path.join(base, "pendulum_sm_frame1.png")) + gt_frame2 = io_ops.read_file(os.path.join(base, "pendulum_sm_frame2.png")) + + image = image_ops.decode_gif(gif) + frame0 = image_ops.decode_png(gt_frame0) + frame1 = image_ops.decode_png(gt_frame1) + frame2 = image_ops.decode_png(gt_frame2) + image, frame0, frame1, frame2 = self.evaluate([image, frame0, frame1, + frame2]) + # Compare decoded gif frames with ground-truth data. + self.assertAllEqual(image[0], frame0) + self.assertAllEqual(image[1], frame1) + self.assertAllEqual(image[2], frame2) + class ConvertImageTest(test_util.TensorFlowTestCase):