Remove deprecated DecodeImageOp which is replaced by DecodeImageV2Op.
PiperOrigin-RevId: 331236403 Change-Id: I653fa60ea3301359611cbb25d37696f308bc5916
This commit is contained in:
parent
1b84cc94ef
commit
067e585de7
@ -67,296 +67,6 @@ FileFormat ClassifyFileFormat(StringPiece data) {
|
||||
return kUnknownFormat;
|
||||
}
|
||||
|
||||
string FileFormatString(FileFormat magic, StringPiece data) {
|
||||
switch (magic) {
|
||||
case kPngFormat:
|
||||
return "PNG";
|
||||
case kJpgFormat:
|
||||
return "JPEG";
|
||||
case kGifFormat:
|
||||
return "GIF";
|
||||
default: {
|
||||
if (data.empty()) return "empty file";
|
||||
return strings::StrCat("unknown format starting with '",
|
||||
absl::CEscape(data.substr(0, 16)), "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode an image (either jpeg, png, or gif). We use a single op so that
|
||||
// users don't have to care about which format they have.
|
||||
// TODO(b/141645641): Separate concerns here: constructors uses name to
|
||||
// determine type of parsing, compute uses file magic to parse and these might
|
||||
// not match.
|
||||
class DecodeImageOp : public OpKernel {
|
||||
public:
|
||||
explicit DecodeImageOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
// Determine which op we are: jpeg, png, gif, or any
|
||||
if (type_string() == "DecodeJpeg") {
|
||||
format_ = kJpgFormat;
|
||||
} else if (type_string() == "DecodeAndCropJpeg") {
|
||||
format_ = kJpgFormat;
|
||||
flags_.crop = true;
|
||||
} else if (type_string() == "DecodePng") {
|
||||
format_ = kPngFormat;
|
||||
} else if (type_string() == "DecodeGif") {
|
||||
format_ = kGifFormat;
|
||||
} else {
|
||||
OP_REQUIRES_OK(context,
|
||||
errors::InvalidArgument("Bad op type ", type_string()));
|
||||
}
|
||||
|
||||
if (format_ == kGifFormat) {
|
||||
channels_ = 3;
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
|
||||
errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ",
|
||||
channels_));
|
||||
}
|
||||
flags_.components = channels_;
|
||||
|
||||
// In the case of png, we support uint16 output
|
||||
if (format_ == kPngFormat) {
|
||||
DataType dt;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt));
|
||||
OP_REQUIRES(
|
||||
context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16,
|
||||
errors::InvalidArgument("Type must be uint8 or uint16, got ", dt));
|
||||
if (dt == DataType::DT_UINT8) {
|
||||
channel_bits_ = 8;
|
||||
} else {
|
||||
channel_bits_ = 16;
|
||||
}
|
||||
}
|
||||
|
||||
// The TensorFlow-chosen default for jpeg decoding is IFAST, sacrificing
|
||||
// image quality for speed.
|
||||
flags_.dct_method = JDCT_IFAST;
|
||||
|
||||
if (format_ == kJpgFormat) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio));
|
||||
OP_REQUIRES(context,
|
||||
flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 ||
|
||||
flags_.ratio == 8,
|
||||
errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ",
|
||||
flags_.ratio));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling",
|
||||
&flags_.fancy_upscaling));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("try_recover_truncated",
|
||||
&flags_.try_recover_truncated_jpeg));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("acceptable_fraction",
|
||||
&flags_.min_acceptable_fraction));
|
||||
|
||||
string dct_method;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
(dct_method.empty() || dct_method == "INTEGER_FAST" ||
|
||||
dct_method == "INTEGER_ACCURATE"),
|
||||
errors::InvalidArgument("dct_method must be one of "
|
||||
"{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}"));
|
||||
if (dct_method == "INTEGER_FAST") {
|
||||
flags_.dct_method = JDCT_IFAST;
|
||||
} else if (dct_method == "INTEGER_ACCURATE") {
|
||||
flags_.dct_method = JDCT_ISLOW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& contents = context->input(0);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
|
||||
errors::InvalidArgument("contents must be scalar, got shape ",
|
||||
contents.shape().DebugString()));
|
||||
|
||||
// Determine format
|
||||
const StringPiece input = contents.scalar<tstring>()();
|
||||
const auto magic = ClassifyFileFormat(input);
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
magic == kJpgFormat || magic == kPngFormat || magic == kGifFormat,
|
||||
errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ",
|
||||
FileFormatString(magic, input)));
|
||||
OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
|
||||
errors::InvalidArgument(
|
||||
FileFormatString(magic, input),
|
||||
" contents are too large for int: ", input.size()));
|
||||
OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8,
|
||||
errors::InvalidArgument(FileFormatString(magic, input),
|
||||
" does not support uint16 output"));
|
||||
|
||||
switch (magic) {
|
||||
case kJpgFormat:
|
||||
DecodeJpeg(context, input);
|
||||
break;
|
||||
case kPngFormat:
|
||||
DecodePng(context, input);
|
||||
break;
|
||||
case kGifFormat:
|
||||
DecodeGif(context, input);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Should never get here after check above";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeJpeg(OpKernelContext* context, StringPiece input) {
|
||||
OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
|
||||
errors::InvalidArgument(
|
||||
"channels must be 0, 1, or 3 for JPEG, got ", channels_));
|
||||
|
||||
// Use local copy of flags to avoid race condition as the class member is
|
||||
// shared among different invocations.
|
||||
jpeg::UncompressFlags flags = flags_;
|
||||
if (flags.crop) {
|
||||
// Update flags to include crop window.
|
||||
const Tensor& crop_window = context->input(1);
|
||||
OP_REQUIRES(context, crop_window.dims() == 1,
|
||||
errors::InvalidArgument("crop_window must be 1-D, got shape ",
|
||||
crop_window.shape().DebugString()));
|
||||
OP_REQUIRES(context, crop_window.dim_size(0) == 4,
|
||||
errors::InvalidArgument("crop_size must have four elements ",
|
||||
crop_window.shape().DebugString()));
|
||||
auto crop_window_vec = crop_window.vec<int32>();
|
||||
flags.crop_y = crop_window_vec(0);
|
||||
flags.crop_x = crop_window_vec(1);
|
||||
flags.crop_height = crop_window_vec(2);
|
||||
flags.crop_width = crop_window_vec(3);
|
||||
}
|
||||
|
||||
// Decode jpeg, allocating tensor once the size is known.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
jpeg::Uncompress(
|
||||
input.data(), input.size(), flags, nullptr /* nwarn */,
|
||||
[=, &output](int width, int height, int channels) -> uint8* {
|
||||
Status status(context->allocate_output(
|
||||
0,
|
||||
format_ == kGifFormat
|
||||
? TensorShape({1, height, width, channels})
|
||||
: TensorShape({height, width, channels}),
|
||||
&output));
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << status;
|
||||
context->SetStatus(status);
|
||||
return nullptr;
|
||||
}
|
||||
return output->flat<uint8>().data();
|
||||
}),
|
||||
errors::InvalidArgument("Invalid JPEG data or crop window, data size ",
|
||||
input.size()));
|
||||
}
|
||||
|
||||
void DecodePng(OpKernelContext* context, StringPiece input) {
|
||||
// Start decoding png to get shape details
|
||||
png::DecodeContext decode;
|
||||
OP_REQUIRES(context,
|
||||
png::CommonInitDecode(input, channels_, channel_bits_, &decode),
|
||||
errors::InvalidArgument("Invalid PNG header, data size ",
|
||||
input.size()));
|
||||
|
||||
// Verify that width and height are not too large:
|
||||
// - verify width and height don't overflow int.
|
||||
// - width can later be multiplied by channels_ and sizeof(uint16), so
|
||||
// verify single dimension is not too large.
|
||||
// - verify when width and height are multiplied together, there are a few
|
||||
// bits to spare as well.
|
||||
const int width = static_cast<int>(decode.width);
|
||||
const int height = static_cast<int>(decode.height);
|
||||
const int64 total_size =
|
||||
static_cast<int64>(width) * static_cast<int64>(height);
|
||||
if (width != static_cast<int64>(decode.width) || width <= 0 ||
|
||||
width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
|
||||
height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
|
||||
png::CommonFreeDecode(&decode);
|
||||
OP_REQUIRES(context, false,
|
||||
errors::InvalidArgument("PNG size too large for int: ",
|
||||
decode.width, " by ", decode.height));
|
||||
}
|
||||
|
||||
// Allocate tensor
|
||||
Tensor* output = nullptr;
|
||||
const auto status = context->allocate_output(
|
||||
0,
|
||||
format_ == kGifFormat ? TensorShape({1, height, width, decode.channels})
|
||||
: TensorShape({height, width, decode.channels}),
|
||||
&output);
|
||||
if (!status.ok()) png::CommonFreeDecode(&decode);
|
||||
OP_REQUIRES_OK(context, status);
|
||||
|
||||
if (channel_bits_ == 8) {
|
||||
// Finish decoding png
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
png::CommonFinishDecode(
|
||||
reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
|
||||
decode.channels * width * sizeof(uint8), &decode),
|
||||
errors::InvalidArgument("Invalid PNG data, size ", input.size()));
|
||||
} else {
|
||||
// Finish decoding png
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
png::CommonFinishDecode(
|
||||
reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
|
||||
decode.channels * width * sizeof(uint16), &decode),
|
||||
errors::InvalidArgument("Invalid PNG data, size ", input.size()));
|
||||
}
|
||||
}
|
||||
|
||||
void DecodeGif(OpKernelContext* context, StringPiece input) {
|
||||
OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
|
||||
errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
|
||||
channels_));
|
||||
|
||||
// Decode GIF, allocating tensor once the size is known.
|
||||
Tensor* output = nullptr;
|
||||
string error_string;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
gif::Decode(input.data(), input.size(),
|
||||
[=, &output](int num_frames, int width, int height,
|
||||
int channels) -> uint8* {
|
||||
Status status;
|
||||
if (format_ == kGifFormat) {
|
||||
status = context->allocate_output(
|
||||
0,
|
||||
TensorShape({num_frames, height, width, channels}),
|
||||
&output);
|
||||
} else if (num_frames == 1) {
|
||||
status = context->allocate_output(
|
||||
0, TensorShape({height, width, channels}), &output);
|
||||
} else {
|
||||
status = errors::InvalidArgument(
|
||||
"Got ", num_frames, " frames, but animated gifs ",
|
||||
"can only be decoded by tf.io.decode_gif or ",
|
||||
"tf.io.decode_image");
|
||||
}
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << status;
|
||||
context->SetStatus(status);
|
||||
return nullptr;
|
||||
}
|
||||
return output->flat<uint8>().data();
|
||||
},
|
||||
&error_string),
|
||||
errors::InvalidArgument("Invalid GIF data (size ", input.size(), "), ",
|
||||
error_string));
|
||||
}
|
||||
|
||||
private:
|
||||
FileFormat format_;
|
||||
int channels_;
|
||||
int channel_bits_ = 8;
|
||||
jpeg::UncompressFlags flags_;
|
||||
};
|
||||
|
||||
// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
|
||||
// a newer version of `DecodeImageOp` for enabling image data parsing to take
|
||||
// place in kernels only, reducing security vulnerabilities and redundancy.
|
||||
|
Loading…
Reference in New Issue
Block a user