diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5f07c70ffeb..5ac588a6017 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1801,6 +1801,24 @@ tf_cc_test( ], ) +tf_cc_test( + name = "encode_jpeg_op_test", + size = "small", + srcs = ["encode_jpeg_op_test.cc"], + deps = [ + ":encode_jpeg_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "encode_wav_op_test", size = "small", diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc index 923b379c2d6..a7b802bb9a2 100644 --- a/tensorflow/core/kernels/encode_jpeg_op.cc +++ b/tensorflow/core/kernels/encode_jpeg_op.cc @@ -164,11 +164,11 @@ class EncodeJpegVariableQualityOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(quality.shape()), errors::InvalidArgument("quality must be scalar: ", quality.shape().DebugString())); + adjusted_flags.quality = quality.scalar()(); OP_REQUIRES(context, 0 <= adjusted_flags.quality && adjusted_flags.quality <= 100, errors::InvalidArgument("quality must be in [0,100], got ", adjusted_flags.quality)); - adjusted_flags.quality = quality.scalar()(); // Autodetect format. int channels; diff --git a/tensorflow/core/kernels/encode_jpeg_op_test.cc b/tensorflow/core/kernels/encode_jpeg_op_test.cc new file mode 100644 index 00000000000..0bac5cd2684 --- /dev/null +++ b/tensorflow/core/kernels/encode_jpeg_op_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using EncodeJpegWithVariableQualityTest = OpsTestBase; + +TEST_F(EncodeJpegWithVariableQualityTest, FailsForInvalidQuality) { + TF_ASSERT_OK(NodeDefBuilder("encode_op", "EncodeJpegVariableQuality") + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + AddInputFromArray(TensorShape({2, 2, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + AddInputFromArray(TensorShape({}), {200}); + Status status = RunOpKernel(); + EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE( + absl::StartsWith(status.error_message(), "quality must be in [0,100]")); +} + +} // namespace +} // namespace tensorflow