Refactor DecodeImageOp for the purpose of removing redundant data parsing and format checks from python wrapper and having them take place only in kernels. Remove security concerns. This change:
- Creates new op kernel (`DecodeImageV2Op`) that can decode all four image formats (jpg, png, gif, bmp). `DecodeImage` is the op name. `DecodeBmpOp` is moved into `DecodeImageV2Op`. (Now we have `gen_image_ops.decode_image` as opposed to previous `decode_image` which was a pure python implementation.) - Updates GIF decoder to take in `expand_animation` flag for decoding just one frame. - Removes data parsing and format checking logic from python layer entirely. - Updates magic bytes for detecting image formats. - Replicates portions of `convert_image_dtype` functionality in kernel (for optionally converting uint8/uint16 -> float32). PiperOrigin-RevId: 317891936 Change-Id: I84f18e053f6dad845d9f2a61e1119f4de131c85d
This commit is contained in:
parent
b07691301f
commit
e4d6335bcb
tensorflow
core
api_def
kernels
lib/gif
ops
python/ops
tools/api/golden
51
tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
Normal file
51
tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt
Normal file
@ -0,0 +1,51 @@
|
||||
op {
|
||||
graph_op_name: "DecodeImage"
|
||||
in_arg {
|
||||
name: "contents"
|
||||
description: <<END
|
||||
0-D. The encoded image bytes.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "image"
|
||||
description: <<END
|
||||
3-D with shape `[height, width, channels]` or 4-D with shape
|
||||
`[frame, height, width, channels]`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "channels"
|
||||
description: <<END
|
||||
Number of color channels for the decoded image.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The desired DType of the returned Tensor.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "expand_animations"
|
||||
description: <<END
|
||||
Controls the output shape of the returned op. If True, the returned op will
|
||||
produce a 3-D tensor for PNG, JPEG, and BMP files; and a 4-D tensor for all
|
||||
GIFs, whether animated or not. If False, the returned op will produce a 3-D
|
||||
tensor for all file types and will truncate animated GIFs to the first frame.
|
||||
END
|
||||
}
|
||||
summary: "Function for decode_bmp, decode_gif, decode_jpeg, and decode_png."
|
||||
description: <<END
|
||||
Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
|
||||
appropriate operation to convert the input bytes string into a Tensor of type
|
||||
dtype.
|
||||
|
||||
*NOTE*: decode_gif returns a 4-D array [num_frames, height, width, 3], as
|
||||
opposed to decode_bmp, decode_jpeg and decode_png, which return 3-D arrays
|
||||
[height, width, num_channels]. Make sure to take this into account when
|
||||
constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or
|
||||
PNG files. Alternately, set the expand_animations argument of this function to
|
||||
False, in which case the op will return 3-dimensional tensors and will truncate
|
||||
animated GIF files to the first frame.
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "DecodeImage"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -17,7 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -33,19 +36,31 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Magic bytes (hex) for each image format.
|
||||
// https://en.wikipedia.org/wiki/List_of_file_signatures
|
||||
// WARNING: Changing `static const` to `constexpr` requires first checking that
|
||||
// it works with supported MSVC version.
|
||||
// https://docs.microsoft.com/en-us/cpp/cpp/constexpr-cpp?redirectedfrom=MSDN&view=vs-2019
|
||||
static const char kPngMagicBytes[] = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A";
|
||||
static const char kGifMagicBytes[] = "\x47\x49\x46\x38";
|
||||
static const char kBmpMagicBytes[] = "\x42\x4d";
|
||||
// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three.
|
||||
static const char kJpegMagicBytes[] = "\xff\xd8\xff";
|
||||
|
||||
enum FileFormat {
|
||||
kUnknownFormat = 0,
|
||||
kPngFormat = 1,
|
||||
kJpgFormat = 2,
|
||||
kGifFormat = 3,
|
||||
kBmpFormat = 4,
|
||||
};
|
||||
|
||||
// Classify the contents of a file based on starting bytes (the magic number).
|
||||
FileFormat ClassifyFileFormat(StringPiece data) {
  // Compare the leading bytes of `data` against each known signature. Every
  // signature is tested exactly once through its named constant; the earlier
  // hard-coded literal checks were byte-for-byte duplicates of these and made
  // the constant-based checks unreachable in practice.
  //
  // Only the first three JPEG bytes are compared because the fourth byte
  // differs between JPEG variants.
  if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat;
  if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat;
  if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat;
  if (absl::StartsWith(data, kBmpMagicBytes)) return kBmpFormat;
  // No signature matched.
  return kUnknownFormat;
}
|
||||
|
||||
@ -339,11 +354,447 @@ class DecodeImageOp : public OpKernel {
|
||||
jpeg::UncompressFlags flags_;
|
||||
};
|
||||
|
||||
// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
|
||||
// a newer version of `DecodeImageOp` for enabling image data parsing to take
|
||||
// place in kernels only, reducing security vulnerabilities and redundancy.
|
||||
class DecodeImageV2Op : public OpKernel {
|
||||
public:
|
||||
// Reads and validates the op attributes.
//
// Construction fails with InvalidArgument when `channels` is not one of
// {0, 1, 3, 4} or when `dtype` is not one of uint8, uint16 or float. The
// `expand_animations` attr is read without further validation.
explicit DecodeImageV2Op(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
  OP_REQUIRES(
      context,
      channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
      errors::InvalidArgument("`channels` must be 0, 1, 3 or 4 but got ",
                              channels_));
  OP_REQUIRES_OK(context, context->GetAttr("dtype", &data_type_));
  OP_REQUIRES(
      context,
      data_type_ == DataType::DT_UINT8 || data_type_ == DataType::DT_UINT16 ||
          data_type_ == DataType::DT_FLOAT,
      // The message previously spelled the types "unit8, unit16".
      errors::InvalidArgument(
          "`dtype` must be uint8, uint16, float but got: ", data_type_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("expand_animations", &expand_animations_));
}
|
||||
|
||||
// Helper for decoding BMP.
|
||||
// Returns `x` unchanged unless the target is big-endian, in which case `x` is
// passed through `le32toh` (little-endian to host order). It is applied to the
// 32-bit fields read out of the BMP header.
inline int32 ByteSwapInt32ForBigEndian(int32 x) {
|
||||
#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
|
||||
return le32toh(x);
|
||||
#else
|
||||
return x;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Entry point of the kernel.
//
// Checks that input 0 is a non-empty scalar string whose size fits in an
// `int`, classifies the image format from its leading bytes and forwards to
// the matching `Decode*V2` helper. Any validation failure, or an unrecognised
// format, is reported on `context` as InvalidArgument.
void Compute(OpKernelContext* context) override {
  const Tensor& contents = context->input(0);
  OP_REQUIRES(
      context, TensorShapeUtils::IsScalar(contents.shape()),
      // Trailing space added so the shape is not glued to the word "shape".
      errors::InvalidArgument("`contents` must be scalar but got shape ",
                              contents.shape().DebugString()));
  const StringPiece input = contents.scalar<tstring>()();
  OP_REQUIRES(context, !input.empty(),
              errors::InvalidArgument("Input is empty."));
  // The decoders are called with the input size; reject inputs that exceed
  // the range of `int`.
  OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
              errors::InvalidArgument(
                  "Input contents are too large for int: ", input.size()));

  // Parse magic bytes to determine file format.
  switch (ClassifyFileFormat(input)) {
    case kJpgFormat:
      DecodeJpegV2(context, input);
      break;
    case kPngFormat:
      DecodePngV2(context, input);
      break;
    case kGifFormat:
      DecodeGifV2(context, input);
      break;
    case kBmpFormat:
      DecodeBmpV2(context, input);
      break;
    case kUnknownFormat:
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Unknown image file format. One of "
                                          "JPEG, PNG, GIF, BMP required."));
      break;
  }
}
|
||||
|
||||
// Decodes a JPEG image into output 0 with shape [height, width, channels].
//
// `jpeg::Uncompress` produces uint8 data only. When the requested dtype is
// uint8 the decoder writes straight into the output tensor; otherwise it
// writes into a temporary uint8 buffer that is then scaled into the uint16 or
// float output tensor.
void DecodeJpegV2(OpKernelContext* context, StringPiece input) {
  OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
              errors::InvalidArgument("JPEG does not support 4 channels."));

  // Use default settings for the `DecodeImage` op, held in a local so that
  // concurrent invocations never share mutable flag state.
  jpeg::UncompressFlags flags = jpeg::UncompressFlags();
  flags.components = channels_;
  flags.dct_method = JDCT_IFAST;

  // Output tensor and the image buffer size (in elements).
  Tensor* output = nullptr;
  int buffer_size = 0;

  // Owner of the temporary decode buffer used for non-uint8 dtypes. It is
  // declared before the decode call so the buffer is released on every exit
  // path, including the one where the decoder reports failure after the
  // allocation callback has already run (previously that buffer leaked).
  std::unique_ptr<uint8[]> scratch;

  uint8* buffer = jpeg::Uncompress(
      input.data(), input.size(), flags, nullptr /* nwarn */,
      [&](int width, int height, int channels) -> uint8* {
        buffer_size = height * width * channels;
        Status status = context->allocate_output(
            0, TensorShape({height, width, channels}), &output);
        if (!status.ok()) {
          VLOG(1) << status;
          context->SetStatus(status);
          return nullptr;
        }

        if (data_type_ == DataType::DT_UINT8) {
          // Decode directly into the output tensor to avoid an extra copy.
          return output->flat<uint8>().data();
        }
        scratch.reset(new uint8[buffer_size]);
        return scratch.get();
      });

  OP_REQUIRES(context, buffer,
              errors::InvalidArgument("jpeg::Uncompress failed."));

  // For uint8 the decoded pixels are already in the output tensor.
  if (data_type_ == DataType::DT_UINT8) {
    return;
  }

  // Convert the uint8 image data to the desired data type.
  // Use eigen threadpooling to speed up the copy operation.
  const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
  TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
  if (data_type_ == DataType::DT_UINT16) {
    // Ratio between the number of representable values of the two types.
    uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                         (std::numeric_limits<uint8>::max() + 1));
    output->flat<uint16>().device(device) =
        buffer_view.cast<uint16>() * scale;
  } else if (data_type_ == DataType::DT_FLOAT) {
    // Multiply by 1/max(uint8) so that the maximum uint8 value maps to 1.0.
    float scale = 1. / std::numeric_limits<uint8>::max();
    output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
  }
}
|
||||
|
||||
// Decodes a PNG image into output 0 with shape [height, width, channels].
//
// uint8 and uint16 results are decoded directly into the output tensor. A
// float result is decoded as uint16 into a temporary buffer first and then
// scaled by 1/max(uint16) into the output tensor.
void DecodePngV2(OpKernelContext* context, StringPiece input) {
  // Only a uint8 result is decoded at 8 bits per channel; every other dtype
  // is decoded at 16 bits.
  const int channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16;

  png::DecodeContext decode;
  OP_REQUIRES(
      context, png::CommonInitDecode(input, channels_, channel_bits, &decode),
      errors::InvalidArgument("Invalid PNG. Failed to initialize decoder."));

  // Guard against oversized images:
  //  - each dimension must survive the conversion to int and be positive;
  //  - a single dimension is later multiplied by the channel count and by
  //    sizeof(uint16), so it is capped at 2^27;
  //  - the pixel count is capped at 2^29 to leave a few spare bits.
  const int width = static_cast<int>(decode.width);
  const int height = static_cast<int>(decode.height);
  const int64 total_size =
      static_cast<int64>(width) * static_cast<int64>(height);
  const bool width_ok = width == static_cast<int64>(decode.width) &&
                        width > 0 && width < (1LL << 27);
  const bool height_ok = height == static_cast<int64>(decode.height) &&
                         height > 0 && height < (1LL << 27);
  if (!width_ok || !height_ok || total_size >= (1LL << 29)) {
    png::CommonFreeDecode(&decode);
    OP_REQUIRES(context, false,
                errors::InvalidArgument("PNG size too large for int: ",
                                        decode.width, " by ", decode.height));
  }

  Tensor* output = nullptr;
  const auto status = context->allocate_output(
      0, TensorShape({height, width, decode.channels}), &output);
  // The decode context must be released before bailing out.
  if (!status.ok()) png::CommonFreeDecode(&decode);
  OP_REQUIRES_OK(context, status);

  switch (data_type_) {
    case DataType::DT_UINT8: {
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(
              reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
              decode.channels * width * sizeof(uint8), &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
      break;
    }
    case DataType::DT_UINT16: {
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(
              reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
              decode.channels * width * sizeof(uint16), &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
      break;
    }
    case DataType::DT_FLOAT: {
      // `png::CommonFinishDecode` cannot emit float, so decode into a
      // temporary uint16 buffer; the `unique_ptr` releases it on every exit
      // path.
      std::unique_ptr<uint16[]> scratch(
          new uint16[height * width * decode.channels]);
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(reinterpret_cast<png_bytep>(scratch.get()),
                                  decode.channels * width * sizeof(uint16),
                                  &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));

      // Scale the uint16 samples into the float output, using the Eigen
      // thread pool to speed up the copy.
      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
      TTypes<uint16, 3>::UnalignedConstTensor decoded(
          scratch.get(), height, width, decode.channels);
      float scale = 1. / std::numeric_limits<uint16>::max();
      output->tensor<float, 3>().device(device) =
          decoded.cast<float>() * scale;
      break;
    }
    default:
      break;
  }
}
|
||||
|
||||
// Decodes a GIF image into output 0.
//
// The output shape is [num_frames, height, width, channels] when
// `expand_animations_` is true and [height, width, channels] otherwise.
// `gif::Decode` produces uint8 data only. When the requested dtype is uint8
// the decoder writes straight into the output tensor; otherwise it writes
// into a temporary uint8 buffer that is then scaled into the uint16 or float
// output tensor.
void DecodeGifV2(OpKernelContext* context, StringPiece input) {
  // GIF has 3 channels.
  OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
              errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
                                      channels_));

  Tensor* output = nullptr;
  int buffer_size = 0;
  string error_string;

  // Owner of the temporary decode buffer used for non-uint8 dtypes. It is
  // declared before the decode call so the buffer is released on every exit
  // path, including the one where the decoder reports failure after the
  // allocation callback has already run (previously that buffer leaked).
  std::unique_ptr<uint8[]> scratch;

  uint8* buffer = gif::Decode(
      input.data(), input.size(),
      [&](int num_frames, int width, int height, int channels) -> uint8* {
        buffer_size = num_frames * height * width * channels;

        // The output tensor is always allocated here; only its rank depends
        // on `expand_animations_`.
        Status status;
        if (expand_animations_) {
          status = context->allocate_output(
              0, TensorShape({num_frames, height, width, channels}), &output);
        } else {
          status = context->allocate_output(
              0, TensorShape({height, width, channels}), &output);
        }
        if (!status.ok()) {
          VLOG(1) << status;
          context->SetStatus(status);
          return nullptr;
        }

        if (data_type_ == DataType::DT_UINT8) {
          // Decode directly into the output tensor to avoid an extra copy.
          return output->flat<uint8>().data();
        }
        scratch.reset(new uint8[buffer_size]);
        return scratch.get();
      },
      &error_string, expand_animations_);

  OP_REQUIRES(context, buffer,
              errors::InvalidArgument("Invalid GIF data (size ", input.size(),
                                      "), ", error_string));

  // For uint8 the decoded pixels are already in the output tensor.
  if (data_type_ == DataType::DT_UINT8) {
    return;
  }

  // Convert the raw uint8 buffer to the desired dtype.
  // Use eigen threadpooling to speed up the copy operation.
  TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
  const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
  if (data_type_ == DataType::DT_UINT16) {
    // Ratio between the number of representable values of the two types.
    uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                         (std::numeric_limits<uint8>::max() + 1));
    output->flat<uint16>().device(device) =
        buffer_view.cast<uint16>() * scale;
  } else if (data_type_ == DataType::DT_FLOAT) {
    // Multiply by 1/max(uint8) so that the maximum uint8 value maps to 1.0.
    float scale = 1. / std::numeric_limits<uint8>::max();
    output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
  }
}
|
||||
|
||||
// Decodes a BMP image into output 0 with shape [|height|, width, channels].
//
// Header fields are read at fixed byte offsets of the input: pixel-data
// offset at 10, width at 18, height at 22 and bits-per-pixel at 28. A
// negative height marks a top-down row layout; otherwise rows are stored
// bottom-up. Pixels are decoded as uint8 and, for uint16 / float outputs,
// scaled from a temporary uint8 buffer.
//
// Fixes relative to the previous implementation:
//  - The channel count derived from the file is kept in a local instead of
//    being written to the `channels_` member. The member is shared by every
//    invocation of the kernel, so writing it raced between concurrent calls
//    and made a file-derived value (e.g. 4) stick, after which later calls
//    failed the `channels_ == 0 || channels_ == 3` check.
//  - The temporary buffer and its tensor view are sized with `abs_height`.
//    Using the signed `height` produced a negative allocation size for
//    top-down bitmaps.
//  - Offset arithmetic is carried out in 64 bits.
void DecodeBmpV2(OpKernelContext* context, StringPiece input) {
  OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
              errors::InvalidArgument(
                  "`channels` must be 0 or 3 for BMP, but got ", channels_));

  OP_REQUIRES(context, (32 <= input.size()),
              errors::InvalidArgument("Incomplete bmp content, requires at "
                                      "least 32 bytes to find the header "
                                      "size, width, height, and bpp, got ",
                                      input.size(), " bytes"));

  const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
  int32 header_size_ = internal::SubtleMustCopy(
      *(reinterpret_cast<const int32*>(img_bytes + 10)));
  const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
  int32 width_ = internal::SubtleMustCopy(
      *(reinterpret_cast<const int32*>(img_bytes + 18)));
  const int32 width = ByteSwapInt32ForBigEndian(width_);
  int32 height_ = internal::SubtleMustCopy(
      *(reinterpret_cast<const int32*>(img_bytes + 22)));
  const int32 height = ByteSwapInt32ForBigEndian(height_);
  // NOTE(review): this reads 32 bits at offset 28; the BMP bits-per-pixel
  // field is presumably only 16 bits wide - confirm against the format spec.
  int32 bpp_ = internal::SubtleMustCopy(
      *(reinterpret_cast<const int32*>(img_bytes + 28)));
  const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);

  // Per-invocation channel count: the requested value if one was given
  // (it must then agree with the file), otherwise the value from the file.
  int channels = channels_;
  if (channels) {
    OP_REQUIRES(context, (channels == bpp / 8),
                errors::InvalidArgument(
                    "channels attribute ", channels,
                    " does not match bits per pixel from file ", bpp / 8));
  } else {
    channels = bpp / 8;
  }

  // Current implementation only supports 1, 3 or 4 channel
  // bitmaps.
  OP_REQUIRES(context, (channels == 1 || channels == 3 || channels == 4),
              errors::InvalidArgument(
                  "Number of channels must be 1, 3 or 4, was ", channels));

  OP_REQUIRES(context, width > 0,
              errors::InvalidArgument("Width must be positive"));
  OP_REQUIRES(context, height != 0,
              errors::InvalidArgument("Height must be nonzero"));
  OP_REQUIRES(context, header_size >= 0,
              errors::InvalidArgument("header size must be nonnegative"));

  // The real requirement is < 2^31 minus some headers and channel data,
  // so rounding down to something that's still ridiculously big.
  OP_REQUIRES(
      context,
      (static_cast<int64>(width) * std::abs(static_cast<int64>(height))) <
          static_cast<int64>(std::numeric_limits<int32_t>::max() / 8),
      errors::InvalidArgument(
          "Total possible pixel bytes must be less than 2^30"));

  // Safe: the size check above already rejected the most negative int32.
  const int32 abs_height = abs(height);

  // There may be padding bytes when the width is not a multiple of 4 bytes.
  const int row_size = (channels * width + 3) / 4 * 4;

  const int64 last_pixel_offset =
      static_cast<int64>(header_size) +
      static_cast<int64>(abs_height - 1) * row_size +
      static_cast<int64>(width - 1) * channels;

  // [expected file size] = [last pixel offset] + [last pixel size=channels]
  const int64 expected_file_size = last_pixel_offset + channels;

  OP_REQUIRES(
      context, (expected_file_size <= input.size()),
      errors::InvalidArgument("Incomplete bmp content, requires at least ",
                              expected_file_size, " bytes, got ",
                              input.size(), " bytes"));

  // If height is negative, data layout is top down; otherwise it's bottom up.
  bool top_down = (height < 0);

  // Decode image, allocating tensor once the image size is known.
  Tensor* output = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output(
                   0, TensorShape({abs_height, width, channels}), &output));

  const uint8* bmp_pixels = &img_bytes[header_size];

  if (data_type_ == DataType::DT_UINT8) {
    DecodeBMP(bmp_pixels, row_size, output->flat<uint8>().data(), width,
              abs_height, channels, top_down);
  } else {
    // Decode into a temporary uint8 buffer sized with the absolute height.
    std::unique_ptr<uint8[]> buffer(
        new uint8[abs_height * width * channels]);
    DecodeBMP(bmp_pixels, row_size, buffer.get(), width, abs_height,
              channels, top_down);
    TTypes<uint8, 3>::UnalignedConstTensor buf(buffer.get(), abs_height,
                                               width, channels);
    // Convert the raw uint8 buffer to desired dtype.
    // Use eigen threadpooling to speed up the copy operation.
    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
    if (data_type_ == DataType::DT_UINT16) {
      // Ratio between the number of representable values of the two types.
      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                           (std::numeric_limits<uint8>::max() + 1));
      output->tensor<uint16, 3>().device(device) = buf.cast<uint16>() * scale;
    } else if (data_type_ == DataType::DT_FLOAT) {
      // Multiply by 1/max(uint8) so that the maximum uint8 value maps to 1.0.
      float scale = 1. / std::numeric_limits<uint8>::max();
      output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
    }
  }
}
|
||||
|
||||
void DecodeBMP(const uint8* input, const int row_size, uint8* const output,
|
||||
const int width, const int height, const int channels,
|
||||
bool top_down);
|
||||
|
||||
private:
|
||||
int channels_ = 0;
|
||||
DataType data_type_;
|
||||
bool expand_animations_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
|
||||
DecodeImageOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodeImage").Device(DEVICE_CPU),
|
||||
DecodeImageV2Op);
|
||||
|
||||
void DecodeImageV2Op::DecodeBMP(const uint8* input, const int row_size,
|
||||
uint8* const output, const int width,
|
||||
const int height, const int channels,
|
||||
bool top_down) {
|
||||
for (int i = 0; i < height; i++) {
|
||||
int src_pos;
|
||||
int dst_pos;
|
||||
|
||||
for (int j = 0; j < width; j++) {
|
||||
if (!top_down) {
|
||||
src_pos = ((height - 1 - i) * row_size) + j * channels;
|
||||
} else {
|
||||
src_pos = i * row_size + j * channels;
|
||||
}
|
||||
|
||||
dst_pos = (i * width + j) * channels;
|
||||
|
||||
switch (channels) {
|
||||
case 1:
|
||||
output[dst_pos] = input[src_pos];
|
||||
break;
|
||||
case 3:
|
||||
// BGR -> RGB
|
||||
output[dst_pos] = input[src_pos + 2];
|
||||
output[dst_pos + 1] = input[src_pos + 1];
|
||||
output[dst_pos + 2] = input[src_pos];
|
||||
break;
|
||||
case 4:
|
||||
// BGRA -> RGBA
|
||||
output[dst_pos] = input[src_pos + 2];
|
||||
output[dst_pos + 1] = input[src_pos + 1];
|
||||
output[dst_pos + 2] = input[src_pos];
|
||||
output[dst_pos + 3] = input[src_pos + 3];
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unexpected number of channels: " << channels;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -55,7 +55,7 @@ static const char* GifErrorStringNonNull(int error_code) {
|
||||
|
||||
uint8* Decode(const void* srcdata, int datasize,
|
||||
const std::function<uint8*(int, int, int, int)>& allocate_output,
|
||||
string* error_string) {
|
||||
string* error_string, bool expand_animations) {
|
||||
int error_code = D_GIF_SUCCEEDED;
|
||||
InputBufferInfo info = {reinterpret_cast<const uint8*>(srcdata), datasize};
|
||||
GifFileType* gif_file =
|
||||
@ -82,10 +82,13 @@ uint8* Decode(const void* srcdata, int datasize,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int target_num_frames = gif_file->ImageCount;
|
||||
if (!expand_animations) target_num_frames = 1;
|
||||
|
||||
// Don't request more memory than needed for each frame, preventing OOM
|
||||
int max_frame_width = 0;
|
||||
int max_frame_height = 0;
|
||||
for (int k = 0; k < gif_file->ImageCount; k++) {
|
||||
for (int k = 0; k < target_num_frames; k++) {
|
||||
SavedImage* si = &gif_file->SavedImages[k];
|
||||
if (max_frame_height < si->ImageDesc.Height)
|
||||
max_frame_height = si->ImageDesc.Height;
|
||||
@ -93,14 +96,14 @@ uint8* Decode(const void* srcdata, int datasize,
|
||||
max_frame_width = si->ImageDesc.Width;
|
||||
}
|
||||
|
||||
const int num_frames = gif_file->ImageCount;
|
||||
const int width = max_frame_width;
|
||||
const int height = max_frame_height;
|
||||
const int channel = 3;
|
||||
|
||||
uint8* const dstdata = allocate_output(num_frames, width, height, channel);
|
||||
uint8* const dstdata =
|
||||
allocate_output(target_num_frames, width, height, channel);
|
||||
if (!dstdata) return nullptr;
|
||||
for (int k = 0; k < num_frames; k++) {
|
||||
for (int k = 0; k < target_num_frames; k++) {
|
||||
uint8* this_dst = dstdata + k * width * channel * height;
|
||||
|
||||
SavedImage* this_image = &gif_file->SavedImages[k];
|
||||
|
@ -44,7 +44,7 @@ namespace gif {
|
||||
|
||||
uint8* Decode(const void* srcdata, int datasize,
|
||||
const std::function<uint8*(int, int, int, int)>& allocate_output,
|
||||
string* error_string);
|
||||
string* error_string, bool expand_animations = true);
|
||||
|
||||
} // namespace gif
|
||||
} // namespace tensorflow
|
||||
|
@ -87,6 +87,40 @@ Status DecodeImageShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Shape function for the `DecodeImage` op.
//
// Requires input 0 to be a scalar and `channels` to be non-negative. When
// `expand_animations` is true the output shape is left fully unknown (GIFs
// yield 4-D results, the other formats 3-D). When it is false the output is
// always 3-D: [?, ?, channels], with the last dimension unknown if
// `channels` is 0.
Status DecodeImageV2ShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));

  int32 channels;
  TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
  bool expand_animations;
  TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));

  if (channels < 0) {
    return errors::InvalidArgument("channels must be non-negative, got ",
                                   channels);
  }
  // A value of 0 means "use whatever the image has", i.e. unknown here.
  DimensionHandle channels_dim;
  if (channels == 0) {
    channels_dim = c->UnknownDim();
  } else {
    channels_dim = c->MakeDim(channels);
  }

  if (!expand_animations) {
    // Every format, GIF included, produces a 3-D result.
    c->set_output(0,
                  c->MakeShape({InferenceContext::kUnknownDim,
                                InferenceContext::kUnknownDim, channels_dim}));
    return Status::OK();
  }

  // The rank depends on the file format, so nothing can be inferred.
  c->set_output(0, c->UnknownShape());
  return Status::OK();
}
|
||||
|
||||
Status EncodeImageShapeFn(InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
|
||||
@ -412,6 +446,17 @@ REGISTER_OP("RandomCrop")
|
||||
});
|
||||
// TODO(shlens): Support variable rank in RandomCrop.
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("DecodeImage")
|
||||
.Input("contents: string")
|
||||
// Setting `channels` to 0 means using the inherent number of channels in
|
||||
// the image.
|
||||
.Attr("channels: int = 0")
|
||||
.Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
|
||||
.Output("image: dtype")
|
||||
.Attr("expand_animations: bool = true")
|
||||
.SetShapeFn(DecodeImageV2ShapeFn);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("DecodeJpeg")
|
||||
.Input("contents: string")
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -61,6 +62,34 @@ TEST(ImageOpsTest, DecodeGif) {
|
||||
INFER_OK(op, "[]", "[?,?,?,3]");
|
||||
}
|
||||
|
||||
TEST(ImageOpTest, DecodeImage) {
|
||||
ShapeInferenceTestOp op("DecodeImage");
|
||||
|
||||
// Rank check.
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
|
||||
|
||||
// Set `expand_animations` to false. Output is always ?,?,?.
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
|
||||
.Input({"img", 0, DT_STRING})
|
||||
.Attr("expand_animations", false)
|
||||
.Finalize(&op.node_def));
|
||||
INFER_OK(op, "[]", "[?,?,?]");
|
||||
|
||||
// Set `expand_animations` to true. Output shape is not known (3D or 4D).
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
|
||||
.Input({"img", 0, DT_STRING})
|
||||
.Attr("expand_animations", true)
|
||||
.Finalize(&op.node_def));
|
||||
INFER_OK(op, "[]", "?");
|
||||
|
||||
// Negative channel value is rejected.
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "DecodeImage")
|
||||
.Input({"img", 0, DT_STRING})
|
||||
.Attr("channels", -1)
|
||||
.Finalize(&op.node_def));
|
||||
INFER_ERROR("channels must be non-negative, got -1", op, "[]");
|
||||
}
|
||||
|
||||
TEST(ImageOpsTest, DecodeImage_ShapeFn) {
|
||||
for (const char* op_name : {"DecodeJpeg", "DecodePng"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
@ -325,8 +354,8 @@ TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
|
||||
|
||||
// Check images.
|
||||
INFER_ERROR("must be rank 4", op, "[1,?,3];?");
|
||||
INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
|
||||
op, "[1,?,?,5];?");
|
||||
INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)", op,
|
||||
"[1,?,?,5];?");
|
||||
|
||||
// Check boxes.
|
||||
INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");
|
||||
|
@ -2634,6 +2634,23 @@ def decode_image(contents,
|
||||
ValueError: On incorrect number of channels.
|
||||
"""
|
||||
with ops.name_scope(name, 'decode_image'):
|
||||
if compat.forward_compatible(2020, 7, 14):
|
||||
channels = 0 if channels is None else channels
|
||||
if dtype not in [dtypes.float32, dtypes.uint8, dtypes.uint16]:
|
||||
dest_dtype = dtype
|
||||
dtype = dtypes.uint16
|
||||
return convert_image_dtype(gen_image_ops.decode_image(
|
||||
contents=contents,
|
||||
channels=channels,
|
||||
expand_animations=expand_animations,
|
||||
dtype=dtype), dest_dtype)
|
||||
else:
|
||||
return gen_image_ops.decode_image(
|
||||
contents=contents,
|
||||
channels=channels,
|
||||
expand_animations=expand_animations,
|
||||
dtype=dtype)
|
||||
|
||||
if channels not in (None, 0, 1, 3, 4):
|
||||
raise ValueError('channels must be in (None, 0, 1, 3, 4)')
|
||||
substr = string_ops.substr(contents, 0, 3)
|
||||
|
@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -5173,107 +5174,141 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DecodeImageTest(test_util.TensorFlowTestCase):
|
||||
|
||||
_FORWARD_COMPATIBILITY_HORIZONS = [
|
||||
(2020, 6, 11),
|
||||
(2020, 7, 11),
|
||||
(2525, 1, 1), # future behavior
|
||||
]
|
||||
|
||||
def testJpegUint16(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/jpeg/testdata"
|
||||
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
|
||||
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/jpeg/testdata"
|
||||
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
|
||||
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testPngUint16(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/png/testdata"
|
||||
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
|
||||
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(
|
||||
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/png/testdata"
|
||||
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
|
||||
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(
|
||||
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
# uint16 NumPy data must round-trip unchanged through encode_png/decode_image
|
||||
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
|
||||
x_str = image_ops_impl.encode_png(x)
|
||||
x_dec = image_ops_impl.decode_image(
|
||||
x_str, channels=3, dtype=dtypes.uint16)
|
||||
self.assertAllEqual(x, x_dec)
|
||||
# uint16 NumPy data must round-trip unchanged through encode_png/decode_image
|
||||
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
|
||||
x_str = image_ops_impl.encode_png(x)
|
||||
x_dec = image_ops_impl.decode_image(
|
||||
x_str, channels=3, dtype=dtypes.uint16)
|
||||
self.assertAllEqual(x, x_dec)
|
||||
|
||||
def testGifUint16(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testBmpUint16(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/bmp/testdata"
|
||||
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
|
||||
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/bmp/testdata"
|
||||
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
|
||||
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
|
||||
dtypes.uint16)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testJpegFloat32(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/jpeg/testdata"
|
||||
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
|
||||
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/jpeg/testdata"
|
||||
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
|
||||
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testPngFloat32(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/png/testdata"
|
||||
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
|
||||
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(
|
||||
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/png/testdata"
|
||||
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
|
||||
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(
|
||||
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testGifFloat32(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testBmpFloat32(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/bmp/testdata"
|
||||
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
|
||||
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/bmp/testdata"
|
||||
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
|
||||
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
|
||||
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
|
||||
dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
def testExpandAnimations(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
image0 = image_ops.decode_image(
|
||||
gif0, dtype=dtypes.float32, expand_animations=False)
|
||||
# image_ops.decode_gif() returns the whole animation as a 4D tensor; compare against its first frame
|
||||
animation = image_ops.decode_gif(gif0)
|
||||
first_frame = array_ops.gather(animation, 0)
|
||||
image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertEqual(len(image0.shape), 3)
|
||||
self.assertAllEqual(list(image0.shape), [40, 20, 3])
|
||||
self.assertAllEqual(image0, image1)
|
||||
for horizon in self._FORWARD_COMPATIBILITY_HORIZONS:
|
||||
with compat.forward_compatibility_horizon(*horizon):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
|
||||
|
||||
# Test `expand_animations=False` case.
|
||||
image0 = image_ops.decode_image(
|
||||
gif0, dtype=dtypes.float32, expand_animations=False)
|
||||
# image_ops.decode_gif() returns the whole animation as a 4D tensor; compare against its first frame
|
||||
animation = image_ops.decode_gif(gif0)
|
||||
first_frame = array_ops.gather(animation, 0)
|
||||
image1 = image_ops.convert_image_dtype(first_frame, dtypes.float32)
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertEqual(len(image0.shape), 3)
|
||||
self.assertAllEqual(list(image0.shape), [40, 20, 3])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
# Test `expand_animations=True` case.
|
||||
image2 = image_ops.decode_image(gif0, dtype=dtypes.float32)
|
||||
image3 = image_ops.convert_image_dtype(animation, dtypes.float32)
|
||||
image2, image3 = self.evaluate([image2, image3])
|
||||
self.assertEqual(len(image2.shape), 4)
|
||||
self.assertAllEqual(list(image2.shape), [12, 40, 20, 3])
|
||||
self.assertAllEqual(image2, image3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1064,6 +1064,10 @@ tf_module {
|
||||
name: "DecodeGif"
|
||||
argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DecodeImage"
|
||||
argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DecodeJSONExample"
|
||||
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1064,6 +1064,10 @@ tf_module {
|
||||
name: "DecodeGif"
|
||||
argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DecodeImage"
|
||||
argspec: "args=[\'contents\', \'channels\', \'dtype\', \'expand_animations\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DecodeJSONExample"
|
||||
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user