Fix bug where EncodeJpegVariableQualityOp did not check the quality parameter range

PiperOrigin-RevId: 271168479
This commit is contained in:
A. Unique TensorFlower 2019-09-25 11:32:28 -07:00 committed by TensorFlower Gardener
parent 18c6ed23e7
commit 62bbb4d2da
3 changed files with 66 additions and 1 deletions

View File

@ -1801,6 +1801,24 @@ tf_cc_test(
],
)
tf_cc_test(
name = "encode_jpeg_op_test",
size = "small",
srcs = ["encode_jpeg_op_test.cc"],
deps = [
":encode_jpeg_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "encode_wav_op_test",
size = "small",

View File

@ -164,11 +164,11 @@ class EncodeJpegVariableQualityOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(quality.shape()),
errors::InvalidArgument("quality must be scalar: ",
quality.shape().DebugString()));
adjusted_flags.quality = quality.scalar<int>()();
OP_REQUIRES(context,
0 <= adjusted_flags.quality && adjusted_flags.quality <= 100,
errors::InvalidArgument("quality must be in [0,100], got ",
adjusted_flags.quality));
adjusted_flags.quality = quality.scalar<int>()();
// Autodetect format.
int channels;

View File

@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using EncodeJpegWithVariableQualityTest = OpsTestBase;
TEST_F(EncodeJpegWithVariableQualityTest, FailsForInvalidQuality) {
TF_ASSERT_OK(NodeDefBuilder("encode_op", "EncodeJpegVariableQuality")
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<uint8>(TensorShape({2, 2, 3}),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
AddInputFromArray<int32>(TensorShape({}), {200});
Status status = RunOpKernel();
EXPECT_TRUE(errors::IsInvalidArgument(status));
EXPECT_TRUE(
absl::StartsWith(status.error_message(), "quality must be in [0,100]"));
}
} // namespace
} // namespace tensorflow