Refactor `DecodeImageOp` to remove redundant data parsing and format checks from the Python wrapper and have them take place only in kernels, reducing security concerns. This change:

- Creates a new op kernel (`DecodeImageV2Op`, registered under the op name `DecodeImage`) that can decode all four image formats (JPEG, PNG, GIF, BMP); `DecodeBmpOp` is folded into it. Python now calls `gen_image_ops.decode_image` instead of the previous pure-Python `decode_image` implementation. (A minimal usage sketch follows this list.)
- Updates the GIF decoder to take an `expand_animations` flag for decoding just the first frame.
- Removes data parsing and format checking logic from the Python layer entirely.
- Updates magic bytes for detecting image formats.
- Replicates portions of `convert_image_dtype` functionality in the kernel (for optionally converting uint8/uint16 data to uint16/float32).
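
A minimal usage sketch of the consolidated path (the file name is hypothetical):

import tensorflow as tf

# Format detection, decoding, and dtype conversion all happen in the kernel.
contents = tf.io.read_file("photo.jpg")  # hypothetical input; any of the four formats works
img_u8 = tf.io.decode_image(contents)                     # uint8, shape [height, width, channels]
img_f32 = tf.io.decode_image(contents, dtype=tf.float32)  # float32, scaled into [0, 1]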

PiperOrigin-RevId: 317891936
Change-Id: I84f18e053f6dad845d9f2a61e1119f4de131c85d
Hye Soo Yang 2020-06-23 10:21:31 -07:00 committed by TensorFlower Gardener
parent b07691301f
commit e4d6335bcb
11 changed files with 738 additions and 95 deletions

View File

@@ -0,0 +1,51 @@
op {
graph_op_name: "DecodeImage"
in_arg {
name: "contents"
description: <<END
0-D. The encoded image bytes.
END
}
out_arg {
name: "image"
description: <<END
3-D with shape `[height, width, channels]` or 4-D with shape
`[frame, height, width, channels]`.
END
}
attr {
name: "channels"
description: <<END
Number of color channels for the decoded image.
END
}
attr {
name: "dtype"
description: <<END
The desired DType of the returned Tensor.
END
}
attr {
name: "expand_animations"
description: <<END
Controls the output shape of the returned op. If True, the returned op will
produce a 3-D tensor for PNG, JPEG, and BMP files; and a 4-D tensor for all
GIFs, whether animated or not. If False, the returned op will produce a 3-D
tensor for all file types and will truncate animated GIFs to the first frame.
END
}
summary: "Function for decode_bmp, decode_gif, decode_jpeg, and decode_png."
description: <<END
Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
appropriate operation to convert the input bytes string into a Tensor of type
dtype.
*NOTE*: decode_gif returns a 4-D array [num_frames, height, width, 3], as
opposed to decode_bmp, decode_jpeg and decode_png, which return 3-D arrays
[height, width, num_channels]. Make sure to take this into account when
constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or
PNG files. Alternately, set the expand_animations argument of this function to
False, in which case the op will return 3-dimensional tensors and will truncate
animated GIF files to the first frame.
END
}
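
For illustration, a short sketch of the rank behavior described above (file name and values are hypothetical):

import tensorflow as tf

gif = tf.io.read_file("animation.gif")  # hypothetical animated GIF
frames = tf.io.decode_image(gif)                          # 4-D: [num_frames, height, width, 3]
first = tf.io.decode_image(gif, expand_animations=False)  # 3-D: [height, width, 3], first frame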

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "DecodeImage"
visibility: HIDDEN
}

View File

@@ -17,7 +17,10 @@ limitations under the License.
#include <memory>
#define EIGEN_USE_THREADS
#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,19 +36,31 @@ limitations under the License.
namespace tensorflow {
namespace {
// Magic bytes (hex) for each image format.
// https://en.wikipedia.org/wiki/List_of_file_signatures
// WARNING: Changing `static const` to `constexpr` requires first checking that
// it works with supported MSVC version.
// https://docs.microsoft.com/en-us/cpp/cpp/constexpr-cpp?redirectedfrom=MSDN&view=vs-2019
static const char kPngMagicBytes[] = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A";
static const char kGifMagicBytes[] = "\x47\x49\x46\x38";
static const char kBmpMagicBytes[] = "\x42\x4d";
// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three.
static const char kJpegMagicBytes[] = "\xff\xd8\xff";
enum FileFormat {
kUnknownFormat = 0,
kPngFormat = 1,
kJpgFormat = 2,
kGifFormat = 3,
kBmpFormat = 4,
};
// Classify the contents of a file based on starting bytes (the magic number).
FileFormat ClassifyFileFormat(StringPiece data) {
// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three
if (absl::StartsWith(data, "\xff\xd8\xff")) return kJpgFormat;
if (absl::StartsWith(data, "\x89PNG\r\n\x1a\n")) return kPngFormat;
if (absl::StartsWith(data, "\x47\x49\x46\x38")) return kGifFormat;
if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat;
if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat;
if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat;
if (absl::StartsWith(data, kBmpMagicBytes)) return kBmpFormat;
return kUnknownFormat;
}
@@ -339,11 +354,447 @@ class DecodeImageOp : public OpKernel {
jpeg::UncompressFlags flags_;
};
// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
// a newer version of `DecodeImageOp` for enabling image data parsing to take
// place in kernels only, reducing security vulnerabilities and redundancy.
class DecodeImageV2Op : public OpKernel {
public:
explicit DecodeImageV2Op(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
OP_REQUIRES(
context,
channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
errors::InvalidArgument("`channels` must be 0, 1, 3 or 4 but got ",
channels_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &data_type_));
OP_REQUIRES(
context,
data_type_ == DataType::DT_UINT8 || data_type_ == DataType::DT_UINT16 ||
data_type_ == DataType::DT_FLOAT,
errors::InvalidArgument(
"`dtype` must be uint8, uint16, or float32 but got: ", data_type_));
OP_REQUIRES_OK(context,
context->GetAttr("expand_animations", &expand_animations_));
}
// Helper for decoding BMP: header fields are stored little-endian, so
// byte-swap them to host order on big-endian platforms.
inline int32 ByteSwapInt32ForBigEndian(int32 x) {
#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
return le32toh(x);
#else
return x;
#endif
}
void Compute(OpKernelContext* context) override {
const Tensor& contents = context->input(0);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(contents.shape()),
errors::InvalidArgument("`contents` must be scalar but got shape ",
contents.shape().DebugString()));
const StringPiece input = contents.scalar<tstring>()();
OP_REQUIRES(context, !input.empty(),
errors::InvalidArgument("Input is empty."));
OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
errors::InvalidArgument(
"Input contents are too large for int: ", input.size()));
// Parse magic bytes to determine file format.
switch (ClassifyFileFormat(input)) {
case kJpgFormat:
DecodeJpegV2(context, input);
break;
case kPngFormat:
DecodePngV2(context, input);
break;
case kGifFormat:
DecodeGifV2(context, input);
break;
case kBmpFormat:
DecodeBmpV2(context, input);
break;
case kUnknownFormat:
OP_REQUIRES(context, false,
errors::InvalidArgument("Unknown image file format. One of "
"JPEG, PNG, GIF, BMP required."));
break;
}
}
void DecodeJpegV2(OpKernelContext* context, StringPiece input) {
OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
errors::InvalidArgument("JPEG does not support 4 channels."));
// Use default settings for `DecodeImage` op. Use local copy of flags to
// avoid race condition as the class member is shared among different
// invocations.
jpeg::UncompressFlags flags = jpeg::UncompressFlags();
flags.components = channels_;
flags.dct_method = JDCT_IFAST;
// Output tensor and the image buffer size.
Tensor* output = nullptr;
int buffer_size = 0;
// Decode JPEG. Directly allocate the output buffer if the data type is
// uint8 (to save an extra copy). Otherwise, allocate a new uint8 buffer
// of `buffer_size`. `jpeg::Uncompress` supports uint8 only.
uint8* buffer = jpeg::Uncompress(
input.data(), input.size(), flags, nullptr /* nwarn */,
[&](int width, int height, int channels) -> uint8* {
buffer_size = height * width * channels;
Status status = context->allocate_output(
0, TensorShape({height, width, channels}), &output);
if (!status.ok()) {
VLOG(1) << status;
context->SetStatus(status);
return nullptr;
}
if (data_type_ == DataType::DT_UINT8) {
return output->flat<uint8>().data();
} else {
return new uint8[buffer_size];
}
});
OP_REQUIRES(context, buffer,
errors::InvalidArgument("jpeg::Uncompress failed."));
// When the desired data type is uint8, the output buffer was already
// filled during the `jpeg::Uncompress` call above; return.
if (data_type_ == DataType::DT_UINT8) {
return;
}
// Make sure we don't forget to deallocate `buffer`.
std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
// Convert uint8 image data to desired data type.
// Use the Eigen thread pool to speed up the copy operation.
const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
if (data_type_ == DataType::DT_UINT16) {
uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
(std::numeric_limits<uint8>::max() + 1));
// Fill output tensor with desired dtype.
output->flat<uint16>().device(device) =
buffer_view.cast<uint16>() * scale;
} else if (data_type_ == DataType::DT_FLOAT) {
float scale = 1. / std::numeric_limits<uint8>::max();
// Fill output tensor with desired dtype.
output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
}
}
void DecodePngV2(OpKernelContext* context, StringPiece input) {
int channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16;
png::DecodeContext decode;
OP_REQUIRES(
context, png::CommonInitDecode(input, channels_, channel_bits, &decode),
errors::InvalidArgument("Invalid PNG. Failed to initialize decoder."));
// Verify that width and height are not too large:
// - verify width and height don't overflow int.
// - width can later be multiplied by channels_ and sizeof(uint16), so
// verify single dimension is not too large.
// - verify when width and height are multiplied together, there are a few
// bits to spare as well.
const int width = static_cast<int>(decode.width);
const int height = static_cast<int>(decode.height);
const int64 total_size =
static_cast<int64>(width) * static_cast<int64>(height);
if (width != static_cast<int64>(decode.width) || width <= 0 ||
width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
png::CommonFreeDecode(&decode);
OP_REQUIRES(context, false,
errors::InvalidArgument("PNG size too large for int: ",
decode.width, " by ", decode.height));
}
Tensor* output = nullptr;
const auto status = context->allocate_output(
0, TensorShape({height, width, decode.channels}), &output);
if (!status.ok()) png::CommonFreeDecode(&decode);
OP_REQUIRES_OK(context, status);
if (data_type_ == DataType::DT_UINT8) {
OP_REQUIRES(
context,
png::CommonFinishDecode(
reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
decode.channels * width * sizeof(uint8), &decode),
errors::InvalidArgument("Invalid PNG data, size ", input.size()));
} else if (data_type_ == DataType::DT_UINT16) {
OP_REQUIRES(
context,
png::CommonFinishDecode(
reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
decode.channels * width * sizeof(uint16), &decode),
errors::InvalidArgument("Invalid PNG data, size ", input.size()));
} else if (data_type_ == DataType::DT_FLOAT) {
// `png::CommonFinishDecode` does not support `float`. First allocate
// uint16 buffer for the image and decode in uint16 (lossless). Wrap the
// buffer in `unique_ptr` so that we don't forget to delete the buffer.
std::unique_ptr<uint16[]> buffer(
new uint16[height * width * decode.channels]);
OP_REQUIRES(
context,
png::CommonFinishDecode(reinterpret_cast<png_bytep>(buffer.get()),
decode.channels * width * sizeof(uint16),
&decode),
errors::InvalidArgument("Invalid PNG data, size ", input.size()));
// Convert uint16 image data to desired data type.
// Use the Eigen thread pool to speed up the copy operation.
const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
TTypes<uint16, 3>::UnalignedConstTensor buf(buffer.get(), height, width,
decode.channels);
float scale = 1. / std::numeric_limits<uint16>::max();
// Fill output tensor with desired dtype.
output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
}
}
void DecodeGifV2(OpKernelContext* context, StringPiece input) {
// GIF has 3 channels.
OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
channels_));
// Decode GIF, allocating the output tensor directly if dtype is uint8;
// otherwise defer filling the tensor until after dtype conversion is
// done. `gif::Decode` supports uint8 only.
Tensor* output = nullptr;
int buffer_size = 0;
string error_string;
uint8* buffer = gif::Decode(
input.data(), input.size(),
[&](int num_frames, int width, int height, int channels) -> uint8* {
buffer_size = num_frames * height * width * channels;
Status status;
if (expand_animations_) {
status = context->allocate_output(
0, TensorShape({num_frames, height, width, channels}), &output);
} else {
status = context->allocate_output(
0, TensorShape({height, width, channels}), &output);
}
if (!status.ok()) {
VLOG(1) << status;
context->SetStatus(status);
return nullptr;
}
if (data_type_ == DataType::DT_UINT8) {
return output->flat<uint8>().data();
} else {
return new uint8[buffer_size];
}
},
&error_string, expand_animations_);
OP_REQUIRES(context, buffer,
errors::InvalidArgument("Invalid GIF data (size ", input.size(),
"), ", error_string));
// When the desired data type is uint8, the output buffer was already
// filled during the `gif::Decode` call above; return.
if (data_type_ == DataType::DT_UINT8) {
return;
}
// Make sure we don't forget to deallocate `buffer`.
std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);
// Convert the raw uint8 buffer to desired dtype.
// Use the Eigen thread pool to speed up the copy operation.
TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
if (data_type_ == DataType::DT_UINT16) {
uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
(std::numeric_limits<uint8>::max() + 1));
// Fill output tensor with desired dtype.
output->flat<uint16>().device(device) =
buffer_view.cast<uint16>() * scale;
} else if (data_type_ == DataType::DT_FLOAT) {
float scale = 1. / std::numeric_limits<uint8>::max();
// Fill output tensor with desired dtype.
output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
}
}
void DecodeBmpV2(OpKernelContext* context, StringPiece input) {
OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
errors::InvalidArgument(
"`channels` must be 0 or 3 for BMP, but got ", channels_));
OP_REQUIRES(context, (32 <= input.size()),
errors::InvalidArgument("Incomplete bmp content, requires at "
"least 32 bytes to find the header "
"size, width, height, and bpp, got ",
input.size(), " bytes"));
const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
int32 header_size_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 10)));
const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
int32 width_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 18)));
const int32 width = ByteSwapInt32ForBigEndian(width_);
int32 height_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 22)));
const int32 height = ByteSwapInt32ForBigEndian(height_);
int32 bpp_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 28)));
const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);
// Use a local copy of `channels_` so the member is not mutated;
// `Compute` may be invoked concurrently for different inputs.
int channels = channels_;
if (channels) {
OP_REQUIRES(context, (channels == bpp / 8),
errors::InvalidArgument(
"channels attribute ", channels,
" does not match bits per pixel from file ", bpp / 8));
} else {
channels = bpp / 8;
}
// Current implementation only supports 1, 3 or 4 channel
// bitmaps.
OP_REQUIRES(context, (channels == 1 || channels == 3 || channels == 4),
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels));
OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
OP_REQUIRES(context, height != 0,
errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
// The real requirement is < 2^31 minus some headers and channel data,
// so rounding down to something that's still ridiculously big.
OP_REQUIRES(
context,
(static_cast<int64>(width) * std::abs(static_cast<int64>(height))) <
static_cast<int64>(std::numeric_limits<int32_t>::max() / 8),
errors::InvalidArgument(
"Total possible pixel bytes must be less than 2^30"));
const int32 abs_height = abs(height);
// There may be padding bytes when the width is not a multiple of 4 bytes.
const int row_size = (channels * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
(width - 1) * channels;
// [expected file size] = [last pixel offset] + [last pixel size=channels]
const int64 expected_file_size = last_pixel_offset + channels;
OP_REQUIRES(
context, (expected_file_size <= input.size()),
errors::InvalidArgument("Incomplete bmp content, requires at least ",
expected_file_size, " bytes, got ",
input.size(), " bytes"));
// if height is negative, data layout is top down
// otherwise, it's bottom up.
bool top_down = (height < 0);
// Decode image, allocating tensor once the image size is known.
Tensor* output = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({abs_height, width, channels}), &output));
const uint8* bmp_pixels = &img_bytes[header_size];
if (data_type_ == DataType::DT_UINT8) {
DecodeBMP(bmp_pixels, row_size, output->flat<uint8>().data(), width,
abs_height, channels, top_down);
} else {
std::unique_ptr<uint8[]> buffer(
new uint8[abs_height * width * channels]);
DecodeBMP(bmp_pixels, row_size, buffer.get(), width, abs_height,
channels, top_down);
TTypes<uint8, 3>::UnalignedConstTensor buf(buffer.get(), abs_height,
width, channels);
// Convert the raw uint8 buffer to desired dtype.
// Use the Eigen thread pool to speed up the copy operation.
const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
if (data_type_ == DataType::DT_UINT16) {
uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
(std::numeric_limits<uint8>::max() + 1));
// Fill output tensor with desired dtype.
output->tensor<uint16, 3>().device(device) = buf.cast<uint16>() * scale;
} else if (data_type_ == DataType::DT_FLOAT) {
float scale = 1. / std::numeric_limits<uint8>::max();
// Fill output tensor with desired dtype.
output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
}
}
}
void DecodeBMP(const uint8* input, const int row_size, uint8* const output,
const int width, const int height, const int channels,
bool top_down);
private:
int channels_ = 0;
DataType data_type_;
bool expand_animations_;
};
REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodeImage").Device(DEVICE_CPU),
DecodeImageV2Op);
void DecodeImageV2Op::DecodeBMP(const uint8* input, const int row_size,
uint8* const output, const int width,
const int height, const int channels,
bool top_down) {
for (int i = 0; i < height; i++) {
int src_pos;
int dst_pos;
for (int j = 0; j < width; j++) {
if (!top_down) {
src_pos = ((height - 1 - i) * row_size) + j * channels;
} else {
src_pos = i * row_size + j * channels;
}
dst_pos = (i * width + j) * channels;
switch (channels) {
case 1:
output[dst_pos] = input[src_pos];
break;
case 3:
// BGR -> RGB
output[dst_pos] = input[src_pos + 2];
output[dst_pos + 1] = input[src_pos + 1];
output[dst_pos + 2] = input[src_pos];
break;
case 4:
// BGRA -> RGBA
output[dst_pos] = input[src_pos + 2];
output[dst_pos + 1] = input[src_pos + 1];
output[dst_pos + 2] = input[src_pos];
output[dst_pos + 3] = input[src_pos + 3];
break;
default:
LOG(FATAL) << "Unexpected number of channels: " << channels;
break;
}
}
}
}
} // namespace
} // namespace tensorflow
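
The cast-and-scale conversions above replicate `convert_image_dtype` for the supported dtypes; a small NumPy sketch of the same arithmetic, for reference (values are illustrative):

import numpy as np

u8 = np.array([0, 1, 128, 255], dtype=np.uint8)
# uint8 -> uint16: multiply by floor((65535 + 1) / (255 + 1)) = 256, as in the kernel.
u16 = u8.astype(np.uint16) * 256
# uint8 -> float32: scale into [0, 1] with 1 / 255, as in the kernel.
f32 = u8.astype(np.float32) * (1.0 / 255.0)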

View File

@@ -55,7 +55,7 @@ static const char* GifErrorStringNonNull(int error_code) {
uint8* Decode(const void* srcdata, int datasize,
const std::function<uint8*(int, int, int, int)>& allocate_output,
string* error_string) {
string* error_string, bool expand_animations) {
int error_code = D_GIF_SUCCEEDED;
InputBufferInfo info = {reinterpret_cast<const uint8*>(srcdata), datasize};
GifFileType* gif_file =
@@ -82,10 +82,13 @@ uint8* Decode(const void* srcdata, int datasize,
return nullptr;
}
int target_num_frames = gif_file->ImageCount;
if (!expand_animations) target_num_frames = 1;
// Don't request more memory than needed for each frame, preventing OOM
int max_frame_width = 0;
int max_frame_height = 0;
for (int k = 0; k < gif_file->ImageCount; k++) {
for (int k = 0; k < target_num_frames; k++) {
SavedImage* si = &gif_file->SavedImages[k];
if (max_frame_height < si->ImageDesc.Height)
max_frame_height = si->ImageDesc.Height;
@@ -93,14 +96,14 @@ uint8* Decode(const void* srcdata, int datasize,
max_frame_width = si->ImageDesc.Width;
}
const int num_frames = gif_file->ImageCount;
const int width = max_frame_width;
const int height = max_frame_height;
const int channel = 3;
uint8* const dstdata = allocate_output(num_frames, width, height, channel);
uint8* const dstdata =
allocate_output(target_num_frames, width, height, channel);
if (!dstdata) return nullptr;
for (int k = 0; k < num_frames; k++) {
for (int k = 0; k < target_num_frames; k++) {
uint8* this_dst = dstdata + k * width * channel * height;
SavedImage* this_image = &gif_file->SavedImages[k];

View File

@@ -44,7 +44,7 @@ namespace gif {
uint8* Decode(const void* srcdata, int datasize,
const std::function<uint8*(int, int, int, int)>& allocate_output,
string* error_string);
string* error_string, bool expand_animations = true);
} // namespace gif
} // namespace tensorflow

View File

@@ -87,6 +87,40 @@ Status DecodeImageShapeFn(InferenceContext* c) {
return Status::OK();
}
Status DecodeImageV2ShapeFn(InferenceContext* c) {
ShapeHandle unused;
int32 channels;
bool expand_animations;
DimensionHandle channels_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
if (channels == 0) {
channels_dim = c->UnknownDim();
} else {
if (channels < 0) {
return errors::InvalidArgument("channels must be non-negative, got ",
channels);
}
channels_dim = c->MakeDim(channels);
}
// With `expand_animations` set to true, GIFs produce 4-D shapes while JPEG,
// PNG, and BMP produce 3-D shapes, so the output rank is unknown. With
// `expand_animations` set to false, all formats produce 3-D shapes.
if (expand_animations) {
c->set_output(0, c->UnknownShape());
return Status::OK();
} else {
c->set_output(0,
c->MakeShape({InferenceContext::kUnknownDim,
InferenceContext::kUnknownDim, channels_dim}));
return Status::OK();
}
}
Status EncodeImageShapeFn(InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
@@ -412,6 +446,17 @@ REGISTER_OP("RandomCrop")
});
// TODO(shlens): Support variable rank in RandomCrop.
// --------------------------------------------------------------------------
REGISTER_OP("DecodeImage")
.Input("contents: string")
// Setting `channels` to 0 means using the inherent number of channels in
// the image.
.Attr("channels: int = 0")
.Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
.Output("image: dtype")
.Attr("expand_animations: bool = true")
.SetShapeFn(DecodeImageV2ShapeFn);
// --------------------------------------------------------------------------
REGISTER_OP("DecodeJpeg")
.Input("contents: string")

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -61,6 +62,34 @@ TEST(ImageOpsTest, DecodeGif) {
INFER_OK(op, "[]", "[?,?,?,3]");
}
TEST(ImageOpsTest, DecodeImage) {
ShapeInferenceTestOp op("DecodeImage");
// Rank check.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
// Set `expand_animations` to false. Output is always ?,?,?.
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
.Input({"img", 0, DT_STRING})
.Attr("expand_animations", false)
.Finalize(&op.node_def));
INFER_OK(op, "[]", "[?,?,?]");
// Set `expand_animations` to true. Output shape is not known (3-D or 4-D).
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
.Input({"img", 0, DT_STRING})
.Attr("expand_animations", true)
.Finalize(&op.node_def));
INFER_OK(op, "[]", "?");
// Negative channel value is rejected.
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
.Input({"img", 0, DT_STRING})
.Attr("channels", -1)
.Finalize(&op.node_def));
INFER_ERROR("channels must be non-negative, got -1", op, "[]");
}
TEST(ImageOpsTest, DecodeImage_ShapeFn) {
for (const char* op_name : {"DecodeJpeg", "DecodePng"}) {
ShapeInferenceTestOp op(op_name);
@@ -325,8 +354,8 @@ TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
// Check images.
INFER_ERROR("must be rank 4", op, "[1,?,3];?");
INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
op, "[1,?,?,5];?");
INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)", op,
"[1,?,?,5];?");
// Check boxes.
INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");

View File

@@ -2634,6 +2634,23 @@ def decode_image(contents,
ValueError: On incorrect number of channels.
"""
with ops.name_scope(name, 'decode_image'):
if compat.forward_compatible(2020, 7, 14):
channels = 0 if channels is None else channels
# The kernel decodes only uint8, uint16, and float32. For any other
# requested dtype, decode losslessly as uint16 and convert afterwards.
if dtype not in [dtypes.float32, dtypes.uint8, dtypes.uint16]:
dest_dtype = dtype
dtype = dtypes.uint16
return convert_image_dtype(gen_image_ops.decode_image(
contents=contents,
channels=channels,
expand_animations=expand_animations,
dtype=dtype), dest_dtype)
else:
return gen_image_ops.decode_image(
contents=contents,
channels=channels,
expand_animations=expand_animations,
dtype=dtype)
if channels not in (None, 0, 1, 3, 4):
raise ValueError('channels must be in (None, 0, 1, 3, 4)')
substr = string_ops.substr(contents, 0, 3)
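
For illustration, a sketch of the forward-compatible path when the requested dtype is not one the kernel decodes natively (`tf.float16` here is a hypothetical choice):

import tensorflow as tf

contents = tf.io.read_file("photo.png")  # hypothetical input
# The wrapper decodes losslessly as uint16 in the kernel, then applies
# convert_image_dtype to reach the requested dtype.
img = tf.io.decode_image(contents, dtype=tf.float16)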

View File

@@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -5173,107 +5174,141 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
@test_util.run_all_in_graph_and_eager_modes
class DecodeImageTest(test_util.TensorFlowTestCase):
_FORWARD_COMPATIBILITY_HORIZONS = [
(2020, 6, 11),
(2020, 7, 11),
(2525, 1, 1), # future behavior
]
def testJpegUint16(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testPngUint16(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
# NumPy conversions should happen before
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
x_str = image_ops_impl.encode_png(x)
x_dec = image_ops_impl.decode_image(
x_str, channels=3, dtype=dtypes.uint16)
self.assertAllEqual(x, x_dec)
# NumPy conversions should happen before
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
x_str = image_ops_impl.encode_png(x)
x_dec = image_ops_impl.decode_image(
x_str, channels=3, dtype=dtypes.uint16)
self.assertAllEqual(x, x_dec)
def testGifUint16(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testBmpUint16(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.uint16)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testJpegFloat32(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testPngFloat32(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testGifFloat32(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testBmpFloat32(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testExpandAnimations(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(
gif0, dtype=dtypes.float32, expand_animations=False)
# image_ops.decode_gif() returns a 4-D tensor; compare against its first frame.
animation = image_ops.decode_gif(gif0)
first_frame = array_ops.gather(animation, 0)
image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertEqual(len(image0.shape), 3)
self.assertAllEqual(list(image0.shape), [40, 20, 3])
self.assertAllEqual(image0, image1)
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
with compat.forward_compatibility_horizon(*horizon):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
# Test `expand_animations=False` case.
image0 = image_ops.decode_image(
gif0, dtype=dtypes.float32, expand_animations=False)
# image_ops.decode_gif() returns a 4-D tensor; compare against its first frame.
animation = image_ops.decode_gif(gif0)
first_frame = array_ops.gather(animation, 0)
image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
image0, image1 = self.evaluate([image0, image1])
self.assertEqual(len(image0.shape), 3)
self.assertAllEqual(list(image0.shape), [40, 20, 3])
self.assertAllEqual(image0, image1)
# Test `expand_animations=True` case.
image2 = image_ops.decode_image(gif0, dtype=dtypes.float32)
image3 = image_ops.convert_image_dtype(animation, dtypes.float32)
image2, image3 = self.evaluate([image2, image3])
self.assertEqual(len(image2.shape), 4)
self.assertAllEqual(list(image2.shape), [12, 40, 20, 3])
self.assertAllEqual(image2, image3)
if __name__ == "__main__":

View File

@@ -1064,6 +1064,10 @@ tf_module {
name: "DecodeGif"
argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DecodeImage"
argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
}
member_method {
name: "DecodeJSONExample"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -1064,6 +1064,10 @@ tf_module {
name: "DecodeGif"
argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DecodeImage"
argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
}
member_method {
name: "DecodeJSONExample"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "