From e4d6335bcb7a73cd8967c2c12339380aa1ae284f Mon Sep 17 00:00:00 2001
From: Hye Soo Yang <hyey@google.com>
Date: Tue, 23 Jun 2020 10:21:31 -0700
Subject: [PATCH] Refactor DecodeImageOp for the purpose of removing redundant
 data parsing and format checks from python wrapper and having them take place
 only in kernels. Remove security concerns. This change:

- Creates new op kernel (`DecodeImageV2Op`) that can decode all four image formats (jpg, png, gif, bmp). `DecodeImage` is the op name. `DecodeBmpOp` is moved into `DecodeImageV2Op`. (Now we have `gen_image_ops.decode_image` as opposed to previous `decode_image` which was a pure python implementation.)
- Updates GIF decoder to take in `expand_animation` flag for decoding just one frame.
- Removes data parsing and format checking logic from python layer entirely.
- Updates magic bytes for detecting image formats.
- Replicates portions of `convert_image_dtype` functionality in kernel (for optionally converting uint8/uint16 -> float32).

PiperOrigin-RevId: 317891936
Change-Id: I84f18e053f6dad845d9f2a61e1119f4de131c85d
---
 .../base_api/api_def_DecodeImage.pbtxt        |  51 ++
 .../python_api/api_def_DecodeImage.pbtxt      |   4 +
 tensorflow/core/kernels/decode_image_op.cc    | 459 +++++++++++++++++-
 tensorflow/core/lib/gif/gif_io.cc             |  13 +-
 tensorflow/core/lib/gif/gif_io.h              |   2 +-
 tensorflow/core/ops/image_ops.cc              |  45 ++
 tensorflow/core/ops/image_ops_test.cc         |  33 +-
 tensorflow/python/ops/image_ops_impl.py       |  17 +
 tensorflow/python/ops/image_ops_test.py       | 201 ++++----
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 11 files changed, 738 insertions(+), 95 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
new file mode 100644
index 00000000000..c534425eb24
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
@@ -0,0 +1,51 @@
+op {
+  graph_op_name: "DecodeImage"
+  in_arg {
+    name: "contents"
+    description: <<END
+0-D. The encoded image bytes.
+END
+  }
+  out_arg {
+    name: "image"
+    description: <<END
+3-D with shape `[height, width, channels]` or 4-D with shape
+`[frame, height, width, channels]`..
+END
+  }
+  attr {
+    name: "channels"
+    description: <<END
+Number of color channels for the decoded image.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The desired DType of the returned Tensor.
+END
+  }
+  attr {
+    name: "expand_animations"
+    description: <<END
+Controls the output shape of the returned op. If True, the returned op will
+produce a 3-D tensor for PNG, JPEG, and BMP files; and a 4-D tensor for all
+GIFs, whether animated or not. If, False, the returned op will produce a 3-D
+tensor for all file types and will truncate animated GIFs to the first frame.
+END
+  }
+  summary: "Function for decode_bmp, decode_gif, decode_jpeg, and decode_png."
+  description: <<END
+Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
+appropriate operation to convert the input bytes string into a Tensor of type
+dtype.
+
+*NOTE*: decode_gif returns a 4-D array [num_frames, height, width, 3], as
+opposed to decode_bmp, decode_jpeg and decode_png, which return 3-D arrays
+[height, width, num_channels]. Make sure to take this into account when
+constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or
+PNG files. Alternately, set the expand_animations argument of this function to
+False, in which case the op will return 3-dimensional tensors and will truncate
+animated GIF files to the first frame.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt
new file mode 100644
index 00000000000..54c4f6eeeee
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DecodeImage"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc
index 3f878ac6b95..8d0c0d89d43 100644
--- a/tensorflow/core/kernels/decode_image_op.cc
+++ b/tensorflow/core/kernels/decode_image_op.cc
@@ -17,7 +17,10 @@ limitations under the License.
 
 #include <memory>
 
+#define EIGEN_USE_THREADS
+
 #include "absl/strings/escaping.h"
+#include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -33,19 +36,31 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// Magic bytes (hex) for each image format.
+// https://en.wikipedia.org/wiki/List_of_file_signatures
+// WARNING: Changing `static const` to `constexpr` requires first checking that
+// it works with supported MSVC version.
+// https://docs.microsoft.com/en-us/cpp/cpp/constexpr-cpp?redirectedfrom=MSDN&view=vs-2019
+static const char kPngMagicBytes[] = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A";
+static const char kGifMagicBytes[] = "\x47\x49\x46\x38";
+static const char kBmpMagicBytes[] = "\x42\x4d";
+// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three.
+static const char kJpegMagicBytes[] = "\xff\xd8\xff";
+
 enum FileFormat {
   kUnknownFormat = 0,
   kPngFormat = 1,
   kJpgFormat = 2,
   kGifFormat = 3,
+  kBmpFormat = 4,
 };
 
 // Classify the contents of a file based on starting bytes (the magic number).
 FileFormat ClassifyFileFormat(StringPiece data) {
-  // The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three
-  if (absl::StartsWith(data, "\xff\xd8\xff")) return kJpgFormat;
-  if (absl::StartsWith(data, "\x89PNG\r\n\x1a\n")) return kPngFormat;
-  if (absl::StartsWith(data, "\x47\x49\x46\x38")) return kGifFormat;
+  if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat;
+  if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat;
+  if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat;
+  if (absl::StartsWith(data, kBmpMagicBytes)) return kBmpFormat;
   return kUnknownFormat;
 }
 
@@ -339,11 +354,447 @@ class DecodeImageOp : public OpKernel {
   jpeg::UncompressFlags flags_;
 };
 
+// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
+// a newer version of `DecodeImageOp` for enabling image data parsing to take
+// place in kernels only, reducing security vulnerabilities and redundancy.
+class DecodeImageV2Op : public OpKernel {
+ public:
+  explicit DecodeImageV2Op(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
+    OP_REQUIRES(
+        context,
+        channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
+        errors::InvalidArgument("`channels` must be 0, 1, 3 or 4 but got ",
+                                channels_));
+    OP_REQUIRES_OK(context, context->GetAttr("dtype", &data_type_));
+    OP_REQUIRES(
+        context,
+        data_type_ == DataType::DT_UINT8 || data_type_ == DataType::DT_UINT16 ||
+            data_type_ == DataType::DT_FLOAT,
+        errors::InvalidArgument(
+            "`dtype` must be unit8, unit16, float but got: ", data_type_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("expand_animations", &expand_animations_));
+  }
+
+  // Helper for decoding BMP.
+  inline int32 ByteSwapInt32ForBigEndian(int32 x) {
+#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+    return le32toh(x);
+#else
+    return x;
+#endif
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& contents = context->input(0);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(contents.shape()),
+        errors::InvalidArgument("`contents` must be scalar but got shape",
+                                contents.shape().DebugString()));
+    const StringPiece input = contents.scalar<tstring>()();
+    OP_REQUIRES(context, !input.empty(),
+                errors::InvalidArgument("Input is empty."));
+    OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
+                errors::InvalidArgument(
+                    "Input contents are too large for int: ", input.size()));
+
+    // Parse magic bytes to determine file format.
+    switch (ClassifyFileFormat(input)) {
+      case kJpgFormat:
+        DecodeJpegV2(context, input);
+        break;
+      case kPngFormat:
+        DecodePngV2(context, input);
+        break;
+      case kGifFormat:
+        DecodeGifV2(context, input);
+        break;
+      case kBmpFormat:
+        DecodeBmpV2(context, input);
+        break;
+      case kUnknownFormat:
+        OP_REQUIRES(context, false,
+                    errors::InvalidArgument("Unknown image file format. One of "
+                                            "JPEG, PNG, GIF, BMP required."));
+        break;
+    }
+  }
+
+  void DecodeJpegV2(OpKernelContext* context, StringPiece input) {
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
+                errors::InvalidArgument("JPEG does not support 4 channels."));
+
+    // Use default settings for `DecodeImage` op. Use local copy of flags to
+    // avoid race condition as the class member is shared among different
+    // invocations.
+    jpeg::UncompressFlags flags = jpeg::UncompressFlags();
+    flags.components = channels_;
+    flags.dct_method = JDCT_IFAST;
+
+    // Output tensor and the image buffer size.
+    Tensor* output = nullptr;
+    int buffer_size = 0;
+
+    // Decode JPEG. Directly allocate to the output buffer if data type is
+    // uint8 (to save extra copying). Otherwise, allocate a new uint8 buffer
+    // with buffer size. `jpeg::Uncompress` support unit8 only.
+    uint8* buffer = jpeg::Uncompress(
+        input.data(), input.size(), flags, nullptr /* nwarn */,
+        [&](int width, int height, int channels) -> uint8* {
+          buffer_size = height * width * channels;
+          Status status = context->allocate_output(
+              0, TensorShape({height, width, channels}), &output);
+          if (!status.ok()) {
+            VLOG(1) << status;
+            context->SetStatus(status);
+            return nullptr;
+          }
+
+          if (data_type_ == DataType::DT_UINT8) {
+            return output->flat<uint8>().data();
+          } else {
+            return new uint8[buffer_size];
+          }
+        });
+
+    OP_REQUIRES(context, buffer,
+                errors::InvalidArgument("jpeg::Uncompress failed."));
+
+    // For when desired data type if unit8, the output buffer is already
+    // allocated during the `jpeg::Uncompress` call above; return.
+    if (data_type_ == DataType::DT_UINT8) {
+      return;
+    }
+    // Make sure we don't forget to deallocate `buffer`.
+    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
+
+    // Convert uint8 image data to desired data type.
+    // Use eigen threadpooling to speed up the copy operation.
+    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
+    if (data_type_ == DataType::DT_UINT16) {
+      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                           (std::numeric_limits<uint8>::max() + 1));
+      // Fill output tensor with desired dtype.
+      output->flat<uint16>().device(device) =
+          buffer_view.cast<uint16>() * scale;
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      float scale = 1. / std::numeric_limits<uint8>::max();
+      // Fill output tensor with desired dtype.
+      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
+    }
+  }
+
+  void DecodePngV2(OpKernelContext* context, StringPiece input) {
+    int channel_bits;
+    channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16;
+    png::DecodeContext decode;
+    OP_REQUIRES(
+        context, png::CommonInitDecode(input, channels_, channel_bits, &decode),
+        errors::InvalidArgument("Invalid PNG. Failed to initialize decoder."));
+
+    // Verify that width and height are not too large:
+    // - verify width and height don't overflow int.
+    // - width can later be multiplied by channels_ and sizeof(uint16), so
+    //   verify single dimension is not too large.
+    // - verify when width and height are multiplied together, there are a few
+    //   bits to spare as well.
+    const int width = static_cast<int>(decode.width);
+    const int height = static_cast<int>(decode.height);
+    const int64 total_size =
+        static_cast<int64>(width) * static_cast<int64>(height);
+    if (width != static_cast<int64>(decode.width) || width <= 0 ||
+        width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
+        height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
+      png::CommonFreeDecode(&decode);
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("PNG size too large for int: ",
+                                          decode.width, " by ", decode.height));
+    }
+
+    Tensor* output = nullptr;
+    const auto status = context->allocate_output(
+        0, TensorShape({height, width, decode.channels}), &output);
+    if (!status.ok()) png::CommonFreeDecode(&decode);
+    OP_REQUIRES_OK(context, status);
+
+    if (data_type_ == DataType::DT_UINT8) {
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(
+              reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
+              decode.channels * width * sizeof(uint8), &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+    } else if (data_type_ == DataType::DT_UINT16) {
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(
+              reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
+              decode.channels * width * sizeof(uint16), &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      // `png::CommonFinishDecode` does not support `float`. First allocate
+      // uint16 buffer for the image and decode in uint16 (lossless). Wrap the
+      // buffer in `unique_ptr` so that we don't forget to delete the buffer.
+      std::unique_ptr<uint16[]> buffer(
+          new uint16[height * width * decode.channels]);
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(reinterpret_cast<png_bytep>(buffer.get()),
+                                  decode.channels * width * sizeof(uint16),
+                                  &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+
+      // Convert uint16 image data to desired data type.
+      // Use eigen threadpooling to speed up the copy operation.
+      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+      TTypes<uint16, 3>::UnalignedConstTensor buf(buffer.get(), height, width,
+                                                  decode.channels);
+      float scale = 1. / std::numeric_limits<uint16>::max();
+      // Fill output tensor with desired dtype.
+      output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
+    }
+  }
+
+  void DecodeGifV2(OpKernelContext* context, StringPiece input) {
+    // GIF has 3 channels.
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
+                errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
+                                        channels_));
+
+    // Decode GIF, allocating tensor if dtype is uint8, otherwise defer tensor
+    // allocation til after dtype conversion is done. `gif`::Decode` supports
+    // uint8 only.
+    Tensor* output = nullptr;
+    int buffer_size = 0;
+    string error_string;
+    uint8* buffer = gif::Decode(
+        input.data(), input.size(),
+        [&](int num_frames, int width, int height, int channels) -> uint8* {
+          buffer_size = num_frames * height * width * channels;
+
+          Status status;
+          if (expand_animations_) {
+            status = context->allocate_output(
+                0, TensorShape({num_frames, height, width, channels}), &output);
+          } else {
+            status = context->allocate_output(
+                0, TensorShape({height, width, channels}), &output);
+          }
+          if (!status.ok()) {
+            VLOG(1) << status;
+            context->SetStatus(status);
+            return nullptr;
+          }
+
+          if (data_type_ == DataType::DT_UINT8) {
+            return output->flat<uint8>().data();
+          } else {
+            return new uint8[buffer_size];
+          }
+        },
+        &error_string, expand_animations_);
+
+    OP_REQUIRES(context, buffer,
+                errors::InvalidArgument("Invalid GIF data (size ", input.size(),
+                                        "), ", error_string));
+
+    // For when desired data type is unit8, the output buffer is already
+    // allocated during the `gif::Decode` call above; return.
+    if (data_type_ == DataType::DT_UINT8) {
+      return;
+    }
+    // Make sure we don't forget to deallocate `buffer`.
+    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
+
+    // Convert the raw uint8 buffer to desired dtype.
+    // Use eigen threadpooling to speed up the copy operation.
+    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
+    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+    if (data_type_ == DataType::DT_UINT16) {
+      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                           (std::numeric_limits<uint8>::max() + 1));
+      // Fill output tensor with desired dtype.
+      output->flat<uint16>().device(device) =
+          buffer_view.cast<uint16>() * scale;
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      float scale = 1. / std::numeric_limits<uint8>::max();
+      // Fill output tensor with desired dtype.
+      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
+    }
+  }
+
+  void DecodeBmpV2(OpKernelContext* context, StringPiece input) {
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
+                errors::InvalidArgument(
+                    "`channels` must be 0 or 3 for BMP, but got ", channels_));
+
+    OP_REQUIRES(context, (32 <= input.size()),
+                errors::InvalidArgument("Incomplete bmp content, requires at "
+                                        "least 32 bytes to find the header "
+                                        "size, width, height, and bpp, got ",
+                                        input.size(), " bytes"));
+
+    const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
+    int32 header_size_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 10)));
+    const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
+    int32 width_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 18)));
+    const int32 width = ByteSwapInt32ForBigEndian(width_);
+    int32 height_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 22)));
+    const int32 height = ByteSwapInt32ForBigEndian(height_);
+    int32 bpp_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 28)));
+    const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);
+
+    if (channels_) {
+      OP_REQUIRES(context, (channels_ == bpp / 8),
+                  errors::InvalidArgument(
+                      "channels attribute ", channels_,
+                      " does not match bits per pixel from file ", bpp / 8));
+    } else {
+      channels_ = bpp / 8;
+    }
+
+    // Current implementation only supports 1, 3 or 4 channel
+    // bitmaps.
+    OP_REQUIRES(context, (channels_ == 1 || channels_ == 3 || channels_ == 4),
+                errors::InvalidArgument(
+                    "Number of channels must be 1, 3 or 4, was ", channels_));
+
+    OP_REQUIRES(context, width > 0,
+                errors::InvalidArgument("Width must be positive"));
+    OP_REQUIRES(context, height != 0,
+                errors::InvalidArgument("Height must be nonzero"));
+    OP_REQUIRES(context, header_size >= 0,
+                errors::InvalidArgument("header size must be nonnegative"));
+
+    // The real requirement is < 2^31 minus some headers and channel data,
+    // so rounding down to something that's still ridiculously big.
+    OP_REQUIRES(
+        context,
+        (static_cast<int64>(width) * std::abs(static_cast<int64>(height))) <
+            static_cast<int64>(std::numeric_limits<int32_t>::max() / 8),
+        errors::InvalidArgument(
+            "Total possible pixel bytes must be less than 2^30"));
+
+    const int32 abs_height = abs(height);
+
+    // there may be padding bytes when the width is not a multiple of 4 bytes
+    const int row_size = (channels_ * width + 3) / 4 * 4;
+
+    const int64 last_pixel_offset = static_cast<int64>(header_size) +
+                                    (abs_height - 1) * row_size +
+                                    (width - 1) * channels_;
+
+    // [expected file size] = [last pixel offset] + [last pixel size=channels]
+    const int64 expected_file_size = last_pixel_offset + channels_;
+
+    OP_REQUIRES(
+        context, (expected_file_size <= input.size()),
+        errors::InvalidArgument("Incomplete bmp content, requires at least ",
+                                expected_file_size, " bytes, got ",
+                                input.size(), " bytes"));
+
+    // if height is negative, data layout is top down
+    // otherwise, it's bottom up.
+    bool top_down = (height < 0);
+
+    // Decode image, allocating tensor once the image size is known.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(
+                     0, TensorShape({abs_height, width, channels_}), &output));
+
+    const uint8* bmp_pixels = &img_bytes[header_size];
+
+    if (data_type_ == DataType::DT_UINT8) {
+      DecodeBMP(bmp_pixels, row_size, output->flat<uint8>().data(), width,
+                abs_height, channels_, top_down);
+    } else {
+      std::unique_ptr<uint8[]> buffer(new uint8[height * width * channels_]);
+      DecodeBMP(bmp_pixels, row_size, buffer.get(), width, abs_height,
+                channels_, top_down);
+      TTypes<uint8, 3>::UnalignedConstTensor buf(buffer.get(), height, width,
+                                                 channels_);
+      // Convert the raw uint8 buffer to desired dtype.
+      // Use eigen threadpooling to speed up the copy operation.
+      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+      if (data_type_ == DataType::DT_UINT16) {
+        uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                             (std::numeric_limits<uint8>::max() + 1));
+        // Fill output tensor with desired dtype.
+        output->tensor<uint16, 3>().device(device) = buf.cast<uint16>() * scale;
+      } else if (data_type_ == DataType::DT_FLOAT) {
+        float scale = 1. / std::numeric_limits<uint8>::max();
+        // Fill output tensor with desired dtype.
+        output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
+      }
+    }
+  }
+
+  void DecodeBMP(const uint8* input, const int row_size, uint8* const output,
+                 const int width, const int height, const int channels,
+                 bool top_down);
+
+ private:
+  int channels_ = 0;
+  DataType data_type_;
+  bool expand_animations_;
+};
+
 REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
 REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
 REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
 REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
                         DecodeImageOp);
+REGISTER_KERNEL_BUILDER(Name("DecodeImage").Device(DEVICE_CPU),
+                        DecodeImageV2Op);
+
+void DecodeImageV2Op::DecodeBMP(const uint8* input, const int row_size,
+                                uint8* const output, const int width,
+                                const int height, const int channels,
+                                bool top_down) {
+  for (int i = 0; i < height; i++) {
+    int src_pos;
+    int dst_pos;
+
+    for (int j = 0; j < width; j++) {
+      if (!top_down) {
+        src_pos = ((height - 1 - i) * row_size) + j * channels;
+      } else {
+        src_pos = i * row_size + j * channels;
+      }
+
+      dst_pos = (i * width + j) * channels;
+
+      switch (channels) {
+        case 1:
+          output[dst_pos] = input[src_pos];
+          break;
+        case 3:
+          // BGR -> RGB
+          output[dst_pos] = input[src_pos + 2];
+          output[dst_pos + 1] = input[src_pos + 1];
+          output[dst_pos + 2] = input[src_pos];
+          break;
+        case 4:
+          // BGRA -> RGBA
+          output[dst_pos] = input[src_pos + 2];
+          output[dst_pos + 1] = input[src_pos + 1];
+          output[dst_pos + 2] = input[src_pos];
+          output[dst_pos + 3] = input[src_pos + 3];
+          break;
+        default:
+          LOG(FATAL) << "Unexpected number of channels: " << channels;
+          break;
+      }
+    }
+  }
+}
 
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc
index dc5406920a4..32e2f6dfa52 100644
--- a/tensorflow/core/lib/gif/gif_io.cc
+++ b/tensorflow/core/lib/gif/gif_io.cc
@@ -55,7 +55,7 @@ static const char* GifErrorStringNonNull(int error_code) {
 
 uint8* Decode(const void* srcdata, int datasize,
               const std::function<uint8*(int, int, int, int)>& allocate_output,
-              string* error_string) {
+              string* error_string, bool expand_animations) {
   int error_code = D_GIF_SUCCEEDED;
   InputBufferInfo info = {reinterpret_cast<const uint8*>(srcdata), datasize};
   GifFileType* gif_file =
@@ -82,10 +82,13 @@ uint8* Decode(const void* srcdata, int datasize,
     return nullptr;
   }
 
+  int target_num_frames = gif_file->ImageCount;
+  if (!expand_animations) target_num_frames = 1;
+
   // Don't request more memory than needed for each frame, preventing OOM
   int max_frame_width = 0;
   int max_frame_height = 0;
-  for (int k = 0; k < gif_file->ImageCount; k++) {
+  for (int k = 0; k < target_num_frames; k++) {
     SavedImage* si = &gif_file->SavedImages[k];
     if (max_frame_height < si->ImageDesc.Height)
       max_frame_height = si->ImageDesc.Height;
@@ -93,14 +96,14 @@ uint8* Decode(const void* srcdata, int datasize,
       max_frame_width = si->ImageDesc.Width;
   }
 
-  const int num_frames = gif_file->ImageCount;
   const int width = max_frame_width;
   const int height = max_frame_height;
   const int channel = 3;
 
-  uint8* const dstdata = allocate_output(num_frames, width, height, channel);
+  uint8* const dstdata =
+      allocate_output(target_num_frames, width, height, channel);
   if (!dstdata) return nullptr;
-  for (int k = 0; k < num_frames; k++) {
+  for (int k = 0; k < target_num_frames; k++) {
     uint8* this_dst = dstdata + k * width * channel * height;
 
     SavedImage* this_image = &gif_file->SavedImages[k];
diff --git a/tensorflow/core/lib/gif/gif_io.h b/tensorflow/core/lib/gif/gif_io.h
index e46a7917398..ae7d5125bd7 100644
--- a/tensorflow/core/lib/gif/gif_io.h
+++ b/tensorflow/core/lib/gif/gif_io.h
@@ -44,7 +44,7 @@ namespace gif {
 
 uint8* Decode(const void* srcdata, int datasize,
               const std::function<uint8*(int, int, int, int)>& allocate_output,
-              string* error_string);
+              string* error_string, bool expand_animations = true);
 
 }  // namespace gif
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index e11f14b8538..43ee65c4ab4 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -87,6 +87,40 @@ Status DecodeImageShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
+Status DecodeImageV2ShapeFn(InferenceContext* c) {
+  ShapeHandle unused;
+  int32 channels;
+  bool expand_animations;
+  DimensionHandle channels_dim;
+
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+  TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
+  TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
+
+  if (channels == 0) {
+    channels_dim = c->UnknownDim();
+  } else {
+    if (channels < 0) {
+      return errors::InvalidArgument("channels must be non-negative, got ",
+                                     channels);
+    }
+    channels_dim = c->MakeDim(channels);
+  }
+
+  // `expand_animations` set to true will return 4-D shapes for GIF. 3-D shapes
+  // will be returned for jpg, png, and bmp. `expand_animations` set to false
+  // will always return 3-D shapes for all (jpg, png, bmp, gif).
+  if (expand_animations) {
+    c->set_output(0, c->UnknownShape());
+    return Status::OK();
+  } else {
+    c->set_output(0,
+                  c->MakeShape({InferenceContext::kUnknownDim,
+                                InferenceContext::kUnknownDim, channels_dim}));
+    return Status::OK();
+  }
+}
+
 Status EncodeImageShapeFn(InferenceContext* c) {
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
@@ -412,6 +446,17 @@ REGISTER_OP("RandomCrop")
     });
 // TODO(shlens): Support variable rank in RandomCrop.
 
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodeImage")
+    .Input("contents: string")
+    // Setting `channels` to 0 means using the inherent number of channels in
+    // the image.
+    .Attr("channels: int = 0")
+    .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
+    .Output("image: dtype")
+    .Attr("expand_animations: bool = true")
+    .SetShapeFn(DecodeImageV2ShapeFn);
+
 // --------------------------------------------------------------------------
 REGISTER_OP("DecodeJpeg")
     .Input("contents: string")
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index e517e750955..4d0c1fceb28 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference_testutil.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -61,6 +62,34 @@ TEST(ImageOpsTest, DecodeGif) {
   INFER_OK(op, "[]", "[?,?,?,3]");
 }
 
+TEST(ImageOpTest, DecodeImage) {
+  ShapeInferenceTestOp op("DecodeImage");
+
+  // Rank check.
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
+
+  // Set `expand_animations` to false. Output is always ?,?,?.
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("expand_animations", false)
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[]", "[?,?,?]");
+
+  // Set `expand_animations` to false. Output shape is not known (3D or 4D).
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("expand_animations", true)
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[]", "?");
+
+  // Negative channel value is rejected.
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("channels", -1)
+                   .Finalize(&op.node_def));
+  INFER_ERROR("channels must be non-negative, got -1", op, "[]");
+}
+
 TEST(ImageOpsTest, DecodeImage_ShapeFn) {
   for (const char* op_name : {"DecodeJpeg", "DecodePng"}) {
     ShapeInferenceTestOp op(op_name);
@@ -325,8 +354,8 @@ TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
 
   // Check images.
   INFER_ERROR("must be rank 4", op, "[1,?,3];?");
-  INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
-      op, "[1,?,?,5];?");
+  INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)", op,
+              "[1,?,?,5];?");
 
   // Check boxes.
   INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 683681b5c98..bbce25724e7 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -2634,6 +2634,23 @@ def decode_image(contents,
     ValueError: On incorrect number of channels.
   """
   with ops.name_scope(name, 'decode_image'):
+    if compat.forward_compatible(2020, 7, 14):
+      channels = 0 if channels is None else channels
+      if dtype not in [dtypes.float32, dtypes.uint8, dtypes.uint16]:
+        dest_dtype = dtype
+        dtype = dtypes.uint16
+        return convert_image_dtype(gen_image_ops.decode_image(
+            contents=contents,
+            channels=channels,
+            expand_animations=expand_animations,
+            dtype=dtype), dest_dtype)
+      else:
+        return gen_image_ops.decode_image(
+            contents=contents,
+            channels=channels,
+            expand_animations=expand_animations,
+            dtype=dtype)
+
     if channels not in (None, 0, 1, 3, 4):
       raise ValueError('channels must be in (None, 0, 1, 3, 4)')
     substr = string_ops.substr(contents, 0, 3)
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 0206ccf9b33..a05209c2038 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -5173,107 +5174,141 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
 @test_util.run_all_in_graph_and_eager_modes
 class DecodeImageTest(test_util.TensorFlowTestCase):
 
+  _FORWARD_COMPATIBILITY_HORIZONS = [
+      (2020, 6, 11),
+      (2020, 7, 11),
+      (2525, 1, 1),  # future behavior
+  ]
+
   def testJpegUint16(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/jpeg/testdata"
-      jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
-      image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
-                                             dtypes.uint16)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/jpeg/testdata"
+          jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+          image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+                                                 dtypes.uint16)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testPngUint16(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/png/testdata"
-      png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
-      image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
-      image1 = image_ops.convert_image_dtype(
-          image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/png/testdata"
+          png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+          image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
+          image1 = image_ops.convert_image_dtype(
+              image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
-      # NumPy conversions should happen before
-      x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
-      x_str = image_ops_impl.encode_png(x)
-      x_dec = image_ops_impl.decode_image(
-          x_str, channels=3, dtype=dtypes.uint16)
-      self.assertAllEqual(x, x_dec)
+          # NumPy conversions should happen before
+          x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
+          x_str = image_ops_impl.encode_png(x)
+          x_dec = image_ops_impl.decode_image(
+              x_str, channels=3, dtype=dtypes.uint16)
+          self.assertAllEqual(x, x_dec)
 
   def testGifUint16(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/gif/testdata"
-      gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
-      image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
-                                             dtypes.uint16)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/gif/testdata"
+          gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+          image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+                                                 dtypes.uint16)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testBmpUint16(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/bmp/testdata"
-      bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
-      image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
-                                             dtypes.uint16)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/bmp/testdata"
+          bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+          image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+                                                 dtypes.uint16)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testJpegFloat32(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/jpeg/testdata"
-      jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
-      image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
-                                             dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/jpeg/testdata"
+          jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+          image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+                                                 dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testPngFloat32(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/png/testdata"
-      png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
-      image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
-      image1 = image_ops.convert_image_dtype(
-          image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/png/testdata"
+          png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+          image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
+          image1 = image_ops.convert_image_dtype(
+              image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testGifFloat32(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/gif/testdata"
-      gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
-      image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
-                                             dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/gif/testdata"
+          gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+          image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+                                                 dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testBmpFloat32(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/bmp/testdata"
-      bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
-      image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
-                                             dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/bmp/testdata"
+          bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+          image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+                                                 dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testExpandAnimations(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/gif/testdata"
-      gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
-      image0 = image_ops.decode_image(
-          gif0, dtype=dtypes.float32, expand_animations=False)
-      # image_ops.decode_png() handles GIFs and returns 3D tensors
-      animation = image_ops.decode_gif(gif0)
-      first_frame = array_ops.gather(animation, 0)
-      image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertEqual(len(image0.shape), 3)
-      self.assertAllEqual(list(image0.shape), [40, 20, 3])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/gif/testdata"
+          gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+
+          # Test `expand_animations=False` case.
+          image0 = image_ops.decode_image(
+              gif0, dtype=dtypes.float32, expand_animations=False)
+          # image_ops.decode_png() handles GIFs and returns 3D tensors
+          animation = image_ops.decode_gif(gif0)
+          first_frame = array_ops.gather(animation, 0)
+          image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertEqual(len(image0.shape), 3)
+          self.assertAllEqual(list(image0.shape), [40, 20, 3])
+          self.assertAllEqual(image0, image1)
+
+          # Test `expand_animations=True` case.
+          image2 = image_ops.decode_image(gif0, dtype=dtypes.float32)
+          image3 = image_ops.convert_image_dtype(animation, dtypes.float32)
+          image2, image3 = self.evaluate([image2, image3])
+          self.assertEqual(len(image2.shape), 4)
+          self.assertAllEqual(list(image2.shape), [12, 40, 20, 3])
+          self.assertAllEqual(image2, image3)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 8e5303cbea4..1d27408735a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1064,6 +1064,10 @@ tf_module {
     name: "DecodeGif"
     argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DecodeImage"
+    argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
+  }
   member_method {
     name: "DecodeJSONExample"
     argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 8e5303cbea4..1d27408735a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1064,6 +1064,10 @@ tf_module {
     name: "DecodeGif"
     argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "DecodeImage"
+    argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
+  }
   member_method {
     name: "DecodeJSONExample"
     argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "