From e4d6335bcb7a73cd8967c2c12339380aa1ae284f Mon Sep 17 00:00:00 2001
From: Hye Soo Yang <hyey@google.com>
Date: Tue, 23 Jun 2020 10:21:31 -0700
Subject: [PATCH] Refactor DecodeImageOp to remove redundant data parsing and
 format checks from the Python wrapper and have them take place only in
 kernels. This removes security concerns.

This change:
- Creates new op kernel (`DecodeImageV2Op`) that can decode all four image
formats (jpg, png, gif, bmp). `DecodeImage` is the op name. `DecodeBmpOp` is
moved into `DecodeImageV2Op`. (Now we have `gen_image_ops.decode_image` as
opposed to the previous `decode_image`, which was a pure Python
implementation.)
- Updates GIF decoder to take in an `expand_animations` flag for decoding just
one frame.
- Removes data parsing and format checking logic from the Python layer
entirely.
- Updates magic bytes for detecting image formats.
- Replicates portions of `convert_image_dtype` functionality in the kernel
(for optionally converting uint8/uint16 -> float32).

PiperOrigin-RevId: 317891936
Change-Id: I84f18e053f6dad845d9f2a61e1119f4de131c85d
---
 .../base_api/api_def_DecodeImage.pbtxt        |  51 ++
 .../python_api/api_def_DecodeImage.pbtxt      |   4 +
 tensorflow/core/kernels/decode_image_op.cc    | 459 +++++++++++++++++-
 tensorflow/core/lib/gif/gif_io.cc             |  13 +-
 tensorflow/core/lib/gif/gif_io.h              |   2 +-
 tensorflow/core/ops/image_ops.cc              |  45 ++
 tensorflow/core/ops/image_ops_test.cc         |  33 +-
 tensorflow/python/ops/image_ops_impl.py       |  17 +
 tensorflow/python/ops/image_ops_test.py       | 201 ++++----
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 11 files changed, 738 insertions(+), 95 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt

diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
new file mode 100644
index 00000000000..c534425eb24
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
@@ -0,0 +1,51 @@
+op {
+  graph_op_name: "DecodeImage"
+  in_arg {
+    name: "contents"
+    description: <<END
+0-D. The encoded image bytes.
+END
+  }
+  out_arg {
+    name: "image"
+    description: <<END
+3-D with shape `[height, width, channels]` or 4-D with shape
+`[frame, height, width, channels]`.
+END
+  }
+  attr {
+    name: "channels"
+    description: <<END
+Number of color channels for the decoded image.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+The desired DType of the returned Tensor.
+END
+  }
+  attr {
+    name: "expand_animations"
+    description: <<END
+Controls the output shape of the returned op. If True, the returned op will
+produce a 3-D tensor for PNG, JPEG, and BMP files, and a 4-D tensor for all
+GIFs, whether animated or not. If False, the returned op will produce a 3-D
+tensor for all file types and will truncate animated GIFs to the first frame.
+END
+  }
+  summary: "Function for decode_bmp, decode_gif, decode_jpeg, and decode_png."
+  description: <<END
+Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
+appropriate operation to convert the input bytes string into a Tensor of type
+dtype.
+
+*NOTE*: decode_gif returns a 4-D array [num_frames, height, width, 3], as
+opposed to decode_bmp, decode_jpeg and decode_png, which return 3-D arrays
+[height, width, num_channels]. Make sure to take this into account when
+constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or
+PNG files. 
Alternatively, set the expand_animations argument of this function to
False, in which case the op will return 3-dimensional tensors and will truncate
animated GIF files to the first frame.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt
new file mode 100644
index 00000000000..54c4f6eeeee
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeImage.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DecodeImage"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc
index 3f878ac6b95..8d0c0d89d43 100644
--- a/tensorflow/core/kernels/decode_image_op.cc
+++ b/tensorflow/core/kernels/decode_image_op.cc
@@ -17,7 +17,10 @@ limitations under the License.
 
 #include <memory>
 
+#define EIGEN_USE_THREADS
+
 #include "absl/strings/escaping.h"
+#include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -33,19 +36,31 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// Magic bytes (hex) for each image format.
+// https://en.wikipedia.org/wiki/List_of_file_signatures
+// WARNING: Changing `static const` to `constexpr` requires first checking that
+// it works with the supported MSVC version.
+// https://docs.microsoft.com/en-us/cpp/cpp/constexpr-cpp?redirectedfrom=MSDN&view=vs-2019
+static const char kPngMagicBytes[] = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A";
+static const char kGifMagicBytes[] = "\x47\x49\x46\x38";
+static const char kBmpMagicBytes[] = "\x42\x4d";
+// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three.
+static const char kJpegMagicBytes[] = "\xff\xd8\xff";
+
 enum FileFormat {
   kUnknownFormat = 0,
   kPngFormat = 1,
   kJpgFormat = 2,
   kGifFormat = 3,
+  kBmpFormat = 4,
 };
 
 // Classify the contents of a file based on starting bytes (the magic number).
 FileFormat ClassifyFileFormat(StringPiece data) {
-  // The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three
-  if (absl::StartsWith(data, "\xff\xd8\xff")) return kJpgFormat;
-  if (absl::StartsWith(data, "\x89PNG\r\n\x1a\n")) return kPngFormat;
-  if (absl::StartsWith(data, "\x47\x49\x46\x38")) return kGifFormat;
+  if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat;
+  if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat;
+  if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat;
+  if (absl::StartsWith(data, kBmpMagicBytes)) return kBmpFormat;
   return kUnknownFormat;
 }
 
@@ -339,11 +354,447 @@ class DecodeImageOp : public OpKernel {
   jpeg::UncompressFlags flags_;
 };
 
+// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
+// a newer version of `DecodeImageOp` that moves all image data parsing into
+// the kernel, reducing security vulnerabilities and redundancy.
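+//
+// For output dtypes other than uint8, the kernel first decodes into a
+// temporary uint8 (or, for PNG, uint16) buffer and then scales the values into
+// the output tensor, replicating part of `tf.image.convert_image_dtype` in
+// C++.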
+class DecodeImageV2Op : public OpKernel {
+ public:
+  explicit DecodeImageV2Op(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
+    OP_REQUIRES(
+        context,
+        channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
+        errors::InvalidArgument("`channels` must be 0, 1, 3 or 4 but got ",
+                                channels_));
+    OP_REQUIRES_OK(context, context->GetAttr("dtype", &data_type_));
+    OP_REQUIRES(
+        context,
+        data_type_ == DataType::DT_UINT8 || data_type_ == DataType::DT_UINT16 ||
+            data_type_ == DataType::DT_FLOAT,
+        errors::InvalidArgument(
+            "`dtype` must be uint8, uint16, or float32 but got: ", data_type_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("expand_animations", &expand_animations_));
+  }
+
+  // Helper for decoding BMP: BMP header fields are little-endian, so byte-swap
+  // them on big-endian hosts.
+  inline int32 ByteSwapInt32ForBigEndian(int32 x) {
+#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+    return le32toh(x);
+#else
+    return x;
+#endif
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& contents = context->input(0);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(contents.shape()),
+        errors::InvalidArgument("`contents` must be scalar but got shape ",
+                                contents.shape().DebugString()));
+    const StringPiece input = contents.scalar<tstring>()();
+    OP_REQUIRES(context, !input.empty(),
+                errors::InvalidArgument("Input is empty."));
+    OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
+                errors::InvalidArgument(
+                    "Input contents are too large for int: ", input.size()));
+
+    // Parse magic bytes to determine file format.
+    switch (ClassifyFileFormat(input)) {
+      case kJpgFormat:
+        DecodeJpegV2(context, input);
+        break;
+      case kPngFormat:
+        DecodePngV2(context, input);
+        break;
+      case kGifFormat:
+        DecodeGifV2(context, input);
+        break;
+      case kBmpFormat:
+        DecodeBmpV2(context, input);
+        break;
+      case kUnknownFormat:
+        OP_REQUIRES(context, false,
+                    errors::InvalidArgument("Unknown image file format. One of "
+                                            "JPEG, PNG, GIF, BMP required."));
+        break;
+    }
+  }
+
+  void DecodeJpegV2(OpKernelContext* context, StringPiece input) {
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
+                errors::InvalidArgument("JPEG does not support 4 channels."));
+
+    // Use default settings for the `DecodeImage` op. Build the flags locally
+    // per invocation so that no mutable state is shared across invocations.
+    jpeg::UncompressFlags flags = jpeg::UncompressFlags();
+    flags.components = channels_;
+    flags.dct_method = JDCT_IFAST;
+
+    // Output tensor and the image buffer size.
+    Tensor* output = nullptr;
+    int buffer_size = 0;
+
+    // Decode JPEG. Directly allocate to the output buffer if the data type is
+    // uint8 (to save an extra copy). Otherwise, allocate a temporary uint8
+    // buffer of `buffer_size` bytes. `jpeg::Uncompress` supports uint8 only.
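+    // `jpeg::Uncompress` invokes the allocator callback once the image
+    // dimensions are known; returning nullptr from the callback makes the
+    // whole decode return nullptr, which is reported below.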
+    uint8* buffer = jpeg::Uncompress(
+        input.data(), input.size(), flags, nullptr /* nwarn */,
+        [&](int width, int height, int channels) -> uint8* {
+          buffer_size = height * width * channels;
+          Status status = context->allocate_output(
+              0, TensorShape({height, width, channels}), &output);
+          if (!status.ok()) {
+            VLOG(1) << status;
+            context->SetStatus(status);
+            return nullptr;
+          }
+
+          if (data_type_ == DataType::DT_UINT8) {
+            return output->flat<uint8>().data();
+          } else {
+            return new uint8[buffer_size];
+          }
+        });
+
+    OP_REQUIRES(context, buffer,
+                errors::InvalidArgument("jpeg::Uncompress failed."));
+
+    // When the desired data type is uint8, the output buffer was already
+    // filled during the `jpeg::Uncompress` call above; return.
+    if (data_type_ == DataType::DT_UINT8) {
+      return;
+    }
+    // Make sure we don't forget to deallocate `buffer`.
+    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
+
+    // Convert uint8 image data to the desired data type.
+    // Use Eigen thread pooling to speed up the copy operation.
+    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
+    if (data_type_ == DataType::DT_UINT16) {
+      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                           (std::numeric_limits<uint8>::max() + 1));
+      // Fill output tensor with desired dtype.
+      output->flat<uint16>().device(device) =
+          buffer_view.cast<uint16>() * scale;
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      float scale = 1. / std::numeric_limits<uint8>::max();
+      // Fill output tensor with desired dtype.
+      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
+    }
+  }
+
+  void DecodePngV2(OpKernelContext* context, StringPiece input) {
+    int channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16;
+    png::DecodeContext decode;
+    OP_REQUIRES(
+        context, png::CommonInitDecode(input, channels_, channel_bits, &decode),
+        errors::InvalidArgument("Invalid PNG. Failed to initialize decoder."));
+
+    // Verify that width and height are not too large:
+    // - verify width and height don't overflow int.
+    // - width can later be multiplied by channels_ and sizeof(uint16), so
+    //   verify single dimension is not too large.
+    // - verify when width and height are multiplied together, there are a few
+    //   bits to spare as well.
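+    // e.g. even at the limits below, 2^29 pixels * 4 channels * sizeof(uint16)
+    // is 2^32 bytes, which still fits comfortably in int64 arithmetic.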
+    const int width = static_cast<int>(decode.width);
+    const int height = static_cast<int>(decode.height);
+    const int64 total_size =
+        static_cast<int64>(width) * static_cast<int64>(height);
+    if (width != static_cast<int64>(decode.width) || width <= 0 ||
+        width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
+        height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
+      png::CommonFreeDecode(&decode);
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("PNG size too large for int: ",
+                                          decode.width, " by ", decode.height));
+    }
+
+    Tensor* output = nullptr;
+    const auto status = context->allocate_output(
+        0, TensorShape({height, width, decode.channels}), &output);
+    if (!status.ok()) png::CommonFreeDecode(&decode);
+    OP_REQUIRES_OK(context, status);
+
+    if (data_type_ == DataType::DT_UINT8) {
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(
+              reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
+              decode.channels * width * sizeof(uint8), &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+    } else if (data_type_ == DataType::DT_UINT16) {
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(
+              reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
+              decode.channels * width * sizeof(uint16), &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      // `png::CommonFinishDecode` does not support `float`. First allocate a
+      // uint16 buffer for the image and decode in uint16 (lossless). Wrap the
+      // buffer in `unique_ptr` so that we don't forget to delete the buffer.
+      std::unique_ptr<uint16[]> buffer(
+          new uint16[height * width * decode.channels]);
+      OP_REQUIRES(
+          context,
+          png::CommonFinishDecode(reinterpret_cast<png_bytep>(buffer.get()),
+                                  decode.channels * width * sizeof(uint16),
+                                  &decode),
+          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
+
+      // Convert uint16 image data to the desired data type.
+      // Use Eigen thread pooling to speed up the copy operation.
+      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+      TTypes<uint16, 3>::UnalignedConstTensor buf(buffer.get(), height, width,
+                                                  decode.channels);
+      float scale = 1. / std::numeric_limits<uint16>::max();
+      // Fill output tensor with desired dtype.
+      output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
+    }
+  }
+
+  void DecodeGifV2(OpKernelContext* context, StringPiece input) {
+    // GIF has 3 channels.
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
+                errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
+                                        channels_));
+
+    // Decode GIF, allocating the tensor if dtype is uint8; otherwise defer
+    // tensor allocation until after dtype conversion is done. `gif::Decode`
+    // supports uint8 only.
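+    // The allocator callback below picks a 4-D shape when `expand_animations`
+    // is true and a 3-D (single-frame) shape otherwise.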
+    Tensor* output = nullptr;
+    int buffer_size = 0;
+    string error_string;
+    uint8* buffer = gif::Decode(
+        input.data(), input.size(),
+        [&](int num_frames, int width, int height, int channels) -> uint8* {
+          buffer_size = num_frames * height * width * channels;
+
+          Status status;
+          if (expand_animations_) {
+            status = context->allocate_output(
+                0, TensorShape({num_frames, height, width, channels}), &output);
+          } else {
+            status = context->allocate_output(
+                0, TensorShape({height, width, channels}), &output);
+          }
+          if (!status.ok()) {
+            VLOG(1) << status;
+            context->SetStatus(status);
+            return nullptr;
+          }
+
+          if (data_type_ == DataType::DT_UINT8) {
+            return output->flat<uint8>().data();
+          } else {
+            return new uint8[buffer_size];
+          }
+        },
+        &error_string, expand_animations_);
+
+    OP_REQUIRES(context, buffer,
+                errors::InvalidArgument("Invalid GIF data (size ", input.size(),
+                                        "), ", error_string));
+
+    // When the desired data type is uint8, the output buffer was already
+    // filled during the `gif::Decode` call above; return.
+    if (data_type_ == DataType::DT_UINT8) {
+      return;
+    }
+    // Make sure we don't forget to deallocate `buffer`.
+    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
+
+    // Convert the raw uint8 buffer to the desired dtype.
+    // Use Eigen thread pooling to speed up the copy operation.
+    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
+    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+    if (data_type_ == DataType::DT_UINT16) {
+      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                           (std::numeric_limits<uint8>::max() + 1));
+      // Fill output tensor with desired dtype.
+      output->flat<uint16>().device(device) =
+          buffer_view.cast<uint16>() * scale;
+    } else if (data_type_ == DataType::DT_FLOAT) {
+      float scale = 1. / std::numeric_limits<uint8>::max();
+      // Fill output tensor with desired dtype.
+      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
+    }
+  }
+
+  void DecodeBmpV2(OpKernelContext* context, StringPiece input) {
+    OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
+                errors::InvalidArgument(
+                    "`channels` must be 0 or 3 for BMP, but got ", channels_));
+
+    OP_REQUIRES(context, (32 <= input.size()),
+                errors::InvalidArgument("Incomplete bmp content, requires at "
+                                        "least 32 bytes to find the header "
+                                        "size, width, height, and bpp, got ",
+                                        input.size(), " bytes"));
+
+    const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
+    int32 header_size_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 10)));
+    const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
+    int32 width_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 18)));
+    const int32 width = ByteSwapInt32ForBigEndian(width_);
+    int32 height_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 22)));
+    const int32 height = ByteSwapInt32ForBigEndian(height_);
+    int32 bpp_ = internal::SubtleMustCopy(
+        *(reinterpret_cast<const int32*>(img_bytes + 28)));
+    const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);
+
+    if (channels_) {
+      OP_REQUIRES(context, (channels_ == bpp / 8),
+                  errors::InvalidArgument(
+                      "channels attribute ", channels_,
+                      " does not match bits per pixel from file ", bpp / 8));
+    } else {
+      channels_ = bpp / 8;
+    }
+
+    // Current implementation only supports 1, 3 or 4 channel
+    // bitmaps.
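+    // `channels_` was either validated against or derived from bpp / 8 above,
+    // so this also rejects unsupported bit depths such as 16 bpp.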
+    OP_REQUIRES(context, (channels_ == 1 || channels_ == 3 || channels_ == 4),
+                errors::InvalidArgument(
+                    "Number of channels must be 1, 3 or 4, was ", channels_));
+
+    OP_REQUIRES(context, width > 0,
+                errors::InvalidArgument("Width must be positive"));
+    OP_REQUIRES(context, height != 0,
+                errors::InvalidArgument("Height must be nonzero"));
+    OP_REQUIRES(context, header_size >= 0,
+                errors::InvalidArgument("Header size must be nonnegative"));
+
+    // The real requirement is < 2^31 minus some headers and channel data,
+    // so rounding down to something that's still ridiculously big.
+    OP_REQUIRES(
+        context,
+        (static_cast<int64>(width) * std::abs(static_cast<int64>(height))) <
+            static_cast<int64>(std::numeric_limits<int32_t>::max() / 8),
+        errors::InvalidArgument(
+            "Total possible pixel bytes must be less than 2^30"));
+
+    const int32 abs_height = abs(height);
+
+    // There may be padding bytes when the width is not a multiple of 4 bytes.
+    const int row_size = (channels_ * width + 3) / 4 * 4;
+
+    const int64 last_pixel_offset = static_cast<int64>(header_size) +
+                                    (abs_height - 1) * row_size +
+                                    (width - 1) * channels_;
+
+    // [expected file size] = [last pixel offset] + [last pixel size=channels]
+    const int64 expected_file_size = last_pixel_offset + channels_;
+
+    OP_REQUIRES(
+        context, (expected_file_size <= input.size()),
+        errors::InvalidArgument("Incomplete bmp content, requires at least ",
+                                expected_file_size, " bytes, got ",
+                                input.size(), " bytes"));
+
+    // If height is negative, the data layout is top-down; otherwise it is
+    // bottom-up.
+    bool top_down = (height < 0);
+
+    // Decode image, allocating tensor once the image size is known.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(
+                     0, TensorShape({abs_height, width, channels_}), &output));
+
+    const uint8* bmp_pixels = &img_bytes[header_size];
+
+    if (data_type_ == DataType::DT_UINT8) {
+      DecodeBMP(bmp_pixels, row_size, output->flat<uint8>().data(), width,
+                abs_height, channels_, top_down);
+    } else {
+      // Use abs_height here: height is negative for top-down bitmaps, which
+      // would otherwise produce a bogus allocation size.
+      std::unique_ptr<uint8[]> buffer(
+          new uint8[abs_height * width * channels_]);
+      DecodeBMP(bmp_pixels, row_size, buffer.get(), width, abs_height,
+                channels_, top_down);
+      TTypes<uint8, 3>::UnalignedConstTensor buf(buffer.get(), abs_height,
+                                                 width, channels_);
+      // Convert the raw uint8 buffer to the desired dtype.
+      // Use Eigen thread pooling to speed up the copy operation.
+      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
+      if (data_type_ == DataType::DT_UINT16) {
+        uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
+                             (std::numeric_limits<uint8>::max() + 1));
+        // Fill output tensor with desired dtype.
+        output->tensor<uint16, 3>().device(device) = buf.cast<uint16>() * scale;
+      } else if (data_type_ == DataType::DT_FLOAT) {
+        float scale = 1. / std::numeric_limits<uint8>::max();
+        // Fill output tensor with desired dtype.
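+        // Scaling by 1/255 maps uint8's [0, 255] onto [0.0, 1.0].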
+ output->tensor<float, 3>().device(device) = buf.cast<float>() * scale; + } + } + } + + void DecodeBMP(const uint8* input, const int row_size, uint8* const output, + const int width, const int height, const int channels, + bool top_down); + + private: + int channels_ = 0; + DataType data_type_; + bool expand_animations_; +}; + REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp); REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp); REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp); REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU), DecodeImageOp); +REGISTER_KERNEL_BUILDER(Name("DecodeImage").Device(DEVICE_CPU), + DecodeImageV2Op); + +void DecodeImageV2Op::DecodeBMP(const uint8* input, const int row_size, + uint8* const output, const int width, + const int height, const int channels, + bool top_down) { + for (int i = 0; i < height; i++) { + int src_pos; + int dst_pos; + + for (int j = 0; j < width; j++) { + if (!top_down) { + src_pos = ((height - 1 - i) * row_size) + j * channels; + } else { + src_pos = i * row_size + j * channels; + } + + dst_pos = (i * width + j) * channels; + + switch (channels) { + case 1: + output[dst_pos] = input[src_pos]; + break; + case 3: + // BGR -> RGB + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + break; + case 4: + // BGRA -> RGBA + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + output[dst_pos + 3] = input[src_pos + 3]; + break; + default: + LOG(FATAL) << "Unexpected number of channels: " << channels; + break; + } + } + } +} } // namespace } // namespace tensorflow diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index dc5406920a4..32e2f6dfa52 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -55,7 +55,7 @@ static const char* GifErrorStringNonNull(int error_code) { uint8* Decode(const void* srcdata, int datasize, const std::function<uint8*(int, int, int, int)>& allocate_output, - string* error_string) { + string* error_string, bool expand_animations) { int error_code = D_GIF_SUCCEEDED; InputBufferInfo info = {reinterpret_cast<const uint8*>(srcdata), datasize}; GifFileType* gif_file = @@ -82,10 +82,13 @@ uint8* Decode(const void* srcdata, int datasize, return nullptr; } + int target_num_frames = gif_file->ImageCount; + if (!expand_animations) target_num_frames = 1; + // Don't request more memory than needed for each frame, preventing OOM int max_frame_width = 0; int max_frame_height = 0; - for (int k = 0; k < gif_file->ImageCount; k++) { + for (int k = 0; k < target_num_frames; k++) { SavedImage* si = &gif_file->SavedImages[k]; if (max_frame_height < si->ImageDesc.Height) max_frame_height = si->ImageDesc.Height; @@ -93,14 +96,14 @@ uint8* Decode(const void* srcdata, int datasize, max_frame_width = si->ImageDesc.Width; } - const int num_frames = gif_file->ImageCount; const int width = max_frame_width; const int height = max_frame_height; const int channel = 3; - uint8* const dstdata = allocate_output(num_frames, width, height, channel); + uint8* const dstdata = + allocate_output(target_num_frames, width, height, channel); if (!dstdata) return nullptr; - for (int k = 0; k < num_frames; k++) { + for (int k = 0; k < target_num_frames; k++) { uint8* this_dst = dstdata + k * width * channel * height; SavedImage* this_image = 
&gif_file->SavedImages[k];
diff --git a/tensorflow/core/lib/gif/gif_io.h b/tensorflow/core/lib/gif/gif_io.h
index e46a7917398..ae7d5125bd7 100644
--- a/tensorflow/core/lib/gif/gif_io.h
+++ b/tensorflow/core/lib/gif/gif_io.h
@@ -44,7 +44,7 @@ namespace gif {
 
 uint8* Decode(const void* srcdata, int datasize,
               const std::function<uint8*(int, int, int, int)>& allocate_output,
-              string* error_string);
+              string* error_string, bool expand_animations = true);
 
 }  // namespace gif
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index e11f14b8538..43ee65c4ab4 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -87,6 +87,40 @@ Status DecodeImageShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
+Status DecodeImageV2ShapeFn(InferenceContext* c) {
+  ShapeHandle unused;
+  int32 channels;
+  bool expand_animations;
+  DimensionHandle channels_dim;
+
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+  TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
+  TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
+
+  if (channels == 0) {
+    channels_dim = c->UnknownDim();
+  } else {
+    if (channels < 0) {
+      return errors::InvalidArgument("channels must be non-negative, got ",
+                                     channels);
+    }
+    channels_dim = c->MakeDim(channels);
+  }
+
+  // When `expand_animations` is true, GIFs decode to 4-D shapes while JPEG,
+  // PNG, and BMP decode to 3-D shapes, so the output rank is unknown. When it
+  // is false, the op always returns a 3-D shape for all formats (JPEG, PNG,
+  // BMP, GIF).
+  if (expand_animations) {
+    c->set_output(0, c->UnknownShape());
+    return Status::OK();
+  } else {
+    c->set_output(0,
+                  c->MakeShape({InferenceContext::kUnknownDim,
+                                InferenceContext::kUnknownDim, channels_dim}));
+    return Status::OK();
+  }
+}
+
 Status EncodeImageShapeFn(InferenceContext* c) {
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
@@ -412,6 +446,17 @@ REGISTER_OP("RandomCrop")
     });
 // TODO(shlens): Support variable rank in RandomCrop.
 
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodeImage")
+    .Input("contents: string")
+    // Setting `channels` to 0 means using the inherent number of channels in
+    // the image.
+    .Attr("channels: int = 0")
+    .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
+    .Output("image: dtype")
+    .Attr("expand_animations: bool = true")
+    .SetShapeFn(DecodeImageV2ShapeFn);
+
 // --------------------------------------------------------------------------
 REGISTER_OP("DecodeJpeg")
     .Input("contents: string")
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index e517e750955..4d0c1fceb28 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference_testutil.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -61,6 +62,34 @@ TEST(ImageOpsTest, DecodeGif) {
   INFER_OK(op, "[]", "[?,?,?,3]");
 }
 
+TEST(ImageOpsTest, DecodeImage) {
+  ShapeInferenceTestOp op("DecodeImage");
+
+  // Rank check.
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
+
+  // Set `expand_animations` to false. Output is always 3-D: [?,?,?].
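+  // (`channels` is left at its default of 0, so the channel dimension is
+  // unknown as well.)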
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("expand_animations", false)
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[]", "[?,?,?]");
+
+  // Set `expand_animations` to true. Output shape is not known (3-D or 4-D).
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("expand_animations", true)
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[]", "?");
+
+  // Negative channel value is rejected.
+  TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
+                   .Input({"img", 0, DT_STRING})
+                   .Attr("channels", -1)
+                   .Finalize(&op.node_def));
+  INFER_ERROR("channels must be non-negative, got -1", op, "[]");
+}
+
 TEST(ImageOpsTest, DecodeImage_ShapeFn) {
   for (const char* op_name : {"DecodeJpeg", "DecodePng"}) {
     ShapeInferenceTestOp op(op_name);
@@ -325,8 +354,8 @@ TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
 
   // Check images.
   INFER_ERROR("must be rank 4", op, "[1,?,3];?");
-  INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
-              op, "[1,?,?,5];?");
+  INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)", op,
+              "[1,?,?,5];?");
 
   // Check boxes.
   INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 683681b5c98..bbce25724e7 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -2634,6 +2634,23 @@ def decode_image(contents,
     ValueError: On incorrect number of channels.
   """
   with ops.name_scope(name, 'decode_image'):
+    if compat.forward_compatible(2020, 7, 14):
+      channels = 0 if channels is None else channels
+      if dtype not in [dtypes.float32, dtypes.uint8, dtypes.uint16]:
+        dest_dtype = dtype
+        dtype = dtypes.uint16
+        return convert_image_dtype(gen_image_ops.decode_image(
+            contents=contents,
+            channels=channels,
+            expand_animations=expand_animations,
+            dtype=dtype), dest_dtype)
+      else:
+        return gen_image_ops.decode_image(
+            contents=contents,
+            channels=channels,
+            expand_animations=expand_animations,
+            dtype=dtype)
+
     if channels not in (None, 0, 1, 3, 4):
       raise ValueError('channels must be in (None, 0, 1, 3, 4)')
     substr = string_ops.substr(contents, 0, 3)
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 0206ccf9b33..a05209c2038 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -5173,107 +5174,141 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
 @test_util.run_all_in_graph_and_eager_modes
 class DecodeImageTest(test_util.TensorFlowTestCase):
 
+  _FORWARD_COMPATIBILITY_HORIZONS = [
+      (2020, 6, 11),
+      (2020, 7, 11),
+      (2525, 1, 1),  # future behavior
+  ]
+
   def testJpegUint16(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/jpeg/testdata"
-      jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
-      image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
-      image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
-                                             dtypes.uint16)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertAllEqual(image0, image1)
+    for horizon in 
self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/jpeg/testdata" + jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg")) + image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0), + dtypes.uint16) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testPngUint16(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/png/testdata" - png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) - image0 = image_ops.decode_image(png0, dtype=dtypes.uint16) - image1 = image_ops.convert_image_dtype( - image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/png/testdata" + png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) + image0 = image_ops.decode_image(png0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype( + image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) - # NumPy conversions should happen before - x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16) - x_str = image_ops_impl.encode_png(x) - x_dec = image_ops_impl.decode_image( - x_str, channels=3, dtype=dtypes.uint16) - self.assertAllEqual(x, x_dec) + # NumPy conversions should happen before + x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16) + x_str = image_ops_impl.encode_png(x) + x_dec = image_ops_impl.decode_image( + x_str, channels=3, dtype=dtypes.uint16) + self.assertAllEqual(x, x_dec) def testGifUint16(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/gif/testdata" - gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) - image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16) - image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), - dtypes.uint16) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/gif/testdata" + gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) + image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), + dtypes.uint16) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testBmpUint16(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/bmp/testdata" - bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp")) - image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16) - image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0), - dtypes.uint16) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/bmp/testdata" + bmp0 = io_ops.read_file(os.path.join(base, 
"lena.bmp")) + image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16) + image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0), + dtypes.uint16) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testJpegFloat32(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/jpeg/testdata" - jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg")) - image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32) - image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0), - dtypes.float32) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/jpeg/testdata" + jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg")) + image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0), + dtypes.float32) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testPngFloat32(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/png/testdata" - png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) - image0 = image_ops.decode_image(png0, dtype=dtypes.float32) - image1 = image_ops.convert_image_dtype( - image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/png/testdata" + png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png")) + image0 = image_ops.decode_image(png0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype( + image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testGifFloat32(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/gif/testdata" - gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) - image0 = image_ops.decode_image(gif0, dtype=dtypes.float32) - image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), - dtypes.float32) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in self._FORWARD_COMPATIBILITY_HORIZONS: + with compat.forward_compatibility_horizon(*horizon): + with self.cached_session(use_gpu=True) as sess: + base = "tensorflow/core/lib/gif/testdata" + gif0 = io_ops.read_file(os.path.join(base, "scan.gif")) + image0 = image_ops.decode_image(gif0, dtype=dtypes.float32) + image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0), + dtypes.float32) + image0, image1 = self.evaluate([image0, image1]) + self.assertAllEqual(image0, image1) def testBmpFloat32(self): - with self.cached_session(use_gpu=True) as sess: - base = "tensorflow/core/lib/bmp/testdata" - bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp")) - image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32) - image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0), - dtypes.float32) - image0, image1 = self.evaluate([image0, image1]) - self.assertAllEqual(image0, image1) + for horizon in 
self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/bmp/testdata"
+          bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+          image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
+          image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+                                                 dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertAllEqual(image0, image1)
 
   def testExpandAnimations(self):
-    with self.cached_session(use_gpu=True) as sess:
-      base = "tensorflow/core/lib/gif/testdata"
-      gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
-      image0 = image_ops.decode_image(
-          gif0, dtype=dtypes.float32, expand_animations=False)
-      # image_ops.decode_png() handles GIFs and returns 3D tensors
-      animation = image_ops.decode_gif(gif0)
-      first_frame = array_ops.gather(animation, 0)
-      image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
-      image0, image1 = self.evaluate([image0, image1])
-      self.assertEqual(len(image0.shape), 3)
-      self.assertAllEqual(list(image0.shape), [40, 20, 3])
-      self.assertAllEqual(image0, image1)
+    for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
+      with compat.forward_compatibility_horizon(*horizon):
+        with self.cached_session(use_gpu=True) as sess:
+          base = "tensorflow/core/lib/gif/testdata"
+          gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+
+          # Test `expand_animations=False` case.
+          image0 = image_ops.decode_image(
+              gif0, dtype=dtypes.float32, expand_animations=False)
+          # decode_gif returns the full 4-D animation; take its first frame
+          # for comparison.
+          animation = image_ops.decode_gif(gif0)
+          first_frame = array_ops.gather(animation, 0)
+          image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
+          image0, image1 = self.evaluate([image0, image1])
+          self.assertEqual(len(image0.shape), 3)
+          self.assertAllEqual(list(image0.shape), [40, 20, 3])
+          self.assertAllEqual(image0, image1)
+
+          # Test `expand_animations=True` case.
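+          # `expand_animations` defaults to True, so all frames come back.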
+ image2 = image_ops.decode_image(gif0, dtype=dtypes.float32) + image3 = image_ops.convert_image_dtype(animation, dtypes.float32) + image2, image3 = self.evaluate([image2, image3]) + self.assertEqual(len(image2.shape), 4) + self.assertAllEqual(list(image2.shape), [12, 40, 20, 3]) + self.assertAllEqual(image2, image3) if __name__ == "__main__": diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 8e5303cbea4..1d27408735a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1064,6 +1064,10 @@ tf_module { name: "DecodeGif" argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DecodeImage" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], " + } member_method { name: "DecodeJSONExample" argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 8e5303cbea4..1d27408735a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1064,6 +1064,10 @@ tf_module { name: "DecodeGif" argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DecodeImage" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], " + } member_method { name: "DecodeJSONExample" argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
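
Example usage of the new decode path (a minimal sketch; "animation.gif" is a
placeholder input, and the shapes in the comments assume an animated GIF):

    import tensorflow as tf

    contents = tf.io.read_file("animation.gif")  # placeholder path

    # Format detection now happens in the kernel via magic bytes; Python does
    # no parsing. With expand_animations=False an animated GIF is truncated to
    # its first frame and a 3-D float32 tensor scaled into [0, 1] is returned;
    # with the default expand_animations=True, a 4-D
    # [frames, height, width, 3] tensor comes back.
    frame = tf.io.decode_image(contents, dtype=tf.float32,
                               expand_animations=False)
    print(frame.shape)  # (height, width, 3)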