Eager execution coverage for image_ops_test.py. Removed run_deprecated_v1 decorators.

Part 14 (class FormatTest, NonMaxSuppression*Test)

PiperOrigin-RevId: 339106641
Change-Id: I504ff77c3713bc294124e3e5b449cd5bc5807786
This commit is contained in:
Hye Soo Yang 2020-10-26 13:33:10 -07:00 committed by TensorFlower Gardener
parent 2eda043cf4
commit 237c3268e9
3 changed files with 232 additions and 106 deletions

View File

@ -43,10 +43,14 @@ static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
const Tensor& scores) {
// The shape of 'scores' is [num_boxes]
OP_REQUIRES(context, scores.dims() == 1,
errors::InvalidArgument("scores must be 1-D",
scores.shape().DebugString()));
OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
errors::InvalidArgument("scores has incompatible shape"));
errors::InvalidArgument(
"scores must be 1-D", scores.shape().DebugString(),
" (Shape must be rank 1 but is rank ", scores.dims(), ")"));
OP_REQUIRES(
context, scores.dim_size(0) == num_boxes,
errors::InvalidArgument("scores has incompatible shape (Dimensions must "
"be equal, but are ",
num_boxes, " and ", scores.dim_size(0), ")"));
}
static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
@ -67,11 +71,14 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
const Tensor& boxes, int* num_boxes) {
// The shape of 'boxes' is [num_boxes, 4]
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
boxes.shape().DebugString()));
errors::InvalidArgument(
"boxes must be 2-D", boxes.shape().DebugString(),
" (Shape must be rank 2 but is rank ", boxes.dims(), ")"));
*num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4,
errors::InvalidArgument("boxes must have 4 columns"));
errors::InvalidArgument("boxes must have 4 columns (Dimension "
"must be 4 but is ",
boxes.dim_size(1), ")"));
}
static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
@ -670,12 +677,16 @@ class NonMaxSuppressionV3Op : public OpKernel {
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max_output_size.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
max_output_size.shape().DebugString(),
" (Shape must be rank 0 but is ", "rank ",
max_output_size.dims(), ")"));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
iou_threshold.shape().DebugString(),
" (Shape must be rank 0 but is rank ",
iou_threshold.dims(), ")"));
const T iou_threshold_val = iou_threshold.scalar<T>()();
OP_REQUIRES(context,
iou_threshold_val >= static_cast<T>(0.0) &&

View File

@ -15,6 +15,8 @@ limitations under the License.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"
#include <limits>
#include "absl/strings/str_cat.h"
@ -23,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/stream_executor/stream_executor.h"
@ -550,30 +551,44 @@ Status CheckValidInputs(const Tensor& boxes, const Tensor& scores,
const Tensor& iou_threshold) {
if (!TensorShapeUtils::IsScalar(max_output_size.shape())) {
return errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString());
max_output_size.shape().DebugString(),
" (Shape must be rank 0 but is ", "rank ",
max_output_size.dims(), ")");
}
if (!TensorShapeUtils::IsScalar(iou_threshold.shape())) {
return errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString());
iou_threshold.shape().DebugString(),
" (Shape must be rank 0 but is rank ",
iou_threshold.dims(), ")");
}
const float iou_threshold_val = iou_threshold.scalar<float>()();
if (iou_threshold_val < 0 || iou_threshold_val > 1) {
return errors::InvalidArgument("iou_threshold must be in [0, 1]");
}
if (boxes.dims() != 2) {
return errors::InvalidArgument("boxes must be a rank 2 tensor!");
return errors::InvalidArgument(
"boxes must be a rank 2 tensor! (Shape must "
"be rank 2 but is rank ",
boxes.dims(), ")");
}
int num_boxes = boxes.dim_size(0);
if (boxes.dim_size(1) != 4) {
return errors::InvalidArgument("boxes must be Nx4");
return errors::InvalidArgument(
"boxes must be Nx4 (Dimension must be 4 but"
" is ",
boxes.dim_size(1), ")");
}
if (scores.dims() != 1) {
return errors::InvalidArgument("scores must be a vector!");
return errors::InvalidArgument(
"scores must be a vector! (Shape must be "
"rank 1 but is rank ",
scores.dims(), ")");
}
if (scores.dim_size(0) != num_boxes) {
return errors::InvalidArgument(
"scores has incompatible shape"); // message must be exactly this
// otherwise tests fail!
"scores has incompatible shape " // message must be exactly this
"(Dimensions must be equal, but are ", // otherwise tests fail!
num_boxes, " and ", scores.dim_size(0), ")");
}
return Status::OK();
}

View File

@ -39,6 +39,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@ -4784,7 +4785,6 @@ class TotalVariationTest(test_util.TensorFlowTestCase):
class FormatTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def testFormats(self):
prefix = "tensorflow/core/lib"
paths = ("png/testdata/lena_gray.png", "jpeg/testdata/jpeg_merge_test1.jpg",
@ -4796,10 +4796,10 @@ class FormatTest(test_util.TensorFlowTestCase):
}
with self.cached_session():
for path in paths:
contents = io_ops.read_file(os.path.join(prefix, path)).eval()
contents = self.evaluate(io_ops.read_file(os.path.join(prefix, path)))
images = {}
for name, decode in decoders.items():
image = decode(contents).eval()
image = self.evaluate(decode(contents))
self.assertEqual(image.ndim, 3)
for prev_name, prev in images.items():
print("path %s, names %s %s, shapes %s %s" %
@ -4817,7 +4817,6 @@ class FormatTest(test_util.TensorFlowTestCase):
class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def NonMaxSuppressionTest(self):
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
@ -4833,50 +4832,56 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
boxes, scores, max_output_size, iou_threshold)
self.assertAllClose(selected_indices, [3, 0, 5])
@test_util.run_deprecated_v1
def testInvalidShape(self):
def nms_func(box, score, iou_thres, score_thres):
return image_ops.non_max_suppression(box, score, iou_thres, score_thres)
iou_thres = 3
score_thres = 0.5
# The boxes should be 2D of shape [num_boxes, 4].
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Shape must be rank 2 but is rank 1"):
boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0])
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
nms_func(boxes, scores, iou_thres, score_thres)
with self.assertRaisesRegex(ValueError, "Dimension must be 4 but is 3"):
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Dimension must be 4 but is 3"):
boxes = constant_op.constant([[0.0, 0.0, 1.0]])
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
nms_func(boxes, scores, iou_thres, score_thres)
# The boxes is of shape [num_boxes, 4], and the scores is
# of shape [num_boxes]. So an error will be thrown.
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Dimensions must be equal, but are 1 and 2"):
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
scores = constant_op.constant([0.9, 0.75])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
nms_func(boxes, scores, iou_thres, score_thres)
# The scores should be 1D of shape [num_boxes].
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Shape must be rank 1 but is rank 2"):
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
scores = constant_op.constant([[0.9]])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
nms_func(boxes, scores, iou_thres, score_thres)
# The max_output_size should be a scalar (0-D).
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Shape must be rank 0 but is rank 1"):
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, [3], 0.5)
nms_func(boxes, scores, [iou_thres], score_thres)
# The iou_threshold should be a scalar (0-D).
with self.assertRaisesRegex(ValueError,
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Shape must be rank 0 but is rank 2"):
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
nms_func(boxes, scores, iou_thres, [[score_thres]])
@test_util.run_deprecated_v1
@test_util.xla_allow_fallback(
"non_max_suppression with dynamic output shape unsupported.")
def testDataTypes(self):
@ -4896,7 +4901,8 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np, dtype=dtype)
selected_indices = gen_image_ops.non_max_suppression_v2(
boxes, scores, max_output_size, iou_threshold).eval()
boxes, scores, max_output_size, iou_threshold)
selected_indices = self.evaluate(selected_indices)
self.assertAllClose(selected_indices, [3, 0, 5])
# gen_image_ops.non_max_suppression_v3
for dtype in [np.float16, np.float32]:
@ -4941,7 +4947,6 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
class NonMaxSuppressionWithScoresTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
@test_util.xla_allow_fallback(
"non_max_suppression with dynamic output shape unsupported.")
def testSelectFromThreeClustersWithSoftNMS(self):
@ -4974,75 +4979,167 @@ class NonMaxSuppressionWithScoresTest(test_util.TensorFlowTestCase):
rtol=1e-2, atol=1e-2)
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@test_util.run_deprecated_v1
@test_util.disable_xla(
"b/141236442: "
"non_max_suppression with dynamic output shape unsupported.")
def testSelectFromThreeClusters(self):
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
max_output_size_np = 5
iou_threshold_np = 0.5
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np)
selected_indices_padded, num_valid_padded = \
image_ops.non_max_suppression_padded(
def testSelectFromThreeClustersV1(self):
with ops.Graph().as_default():
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
max_output_size_np = 5
iou_threshold_np = 0.5
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np)
selected_indices_padded, num_valid_padded = \
image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
pad_to_max_output_size=True)
selected_indices, num_valid = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
pad_to_max_output_size=False)
# The output shape of the padded operation must be fully defined.
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
with self.cached_session():
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
self.assertEqual(num_valid_padded.eval(), 3)
self.assertAllClose(selected_indices, [3, 0, 5])
self.assertEqual(num_valid.eval(), 3)
@parameterized.named_parameters([("_RunEagerly", True), ("_RunGraph", False)])
@test_util.disable_xla(
"b/141236442: "
"non_max_suppression with dynamic output shape unsupported.")
def testSelectFromThreeClustersV2(self, run_func_eagerly):
if not context.executing_eagerly() and run_func_eagerly:
# Skip running tf.function eagerly in V1 mode.
self.skipTest("Skip test that runs tf.function eagerly in V1 mode.")
else:
@def_function.function
def func(boxes, scores, max_output_size, iou_threshold):
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np)
yp, nvp = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
pad_to_max_output_size=True)
selected_indices, num_valid = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
pad_to_max_output_size=False)
# The output shape of the padded operation must be fully defined.
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
with self.cached_session():
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
self.assertEqual(num_valid_padded.eval(), 3)
self.assertAllClose(selected_indices, [3, 0, 5])
self.assertEqual(num_valid.eval(), 3)
@test_util.run_deprecated_v1
y, n = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
pad_to_max_output_size=False)
# The output shape of the padded operation must be fully defined.
self.assertEqual(yp.shape.is_fully_defined(), True)
self.assertEqual(y.shape.is_fully_defined(), False)
return yp, nvp, y, n
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
max_output_size_np = 5
iou_threshold_np = 0.5
selected_indices_padded, num_valid_padded, selected_indices, num_valid = \
func(boxes_np, scores_np, max_output_size_np, iou_threshold_np)
with self.cached_session():
with test_util.run_functions_eagerly(run_func_eagerly):
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
self.assertEqual(self.evaluate(num_valid_padded), 3)
self.assertAllClose(selected_indices, [3, 0, 5])
self.assertEqual(self.evaluate(num_valid), 3)
@test_util.xla_allow_fallback(
"non_max_suppression with dynamic output shape unsupported.")
def testSelectFromContinuousOverLap(self):
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
max_output_size_np = 3
iou_threshold_np = 0.5
score_threshold_np = 0.1
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np)
score_threshold = constant_op.constant(score_threshold_np)
selected_indices, num_valid = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
score_threshold)
# The output shape of the padded operation must be fully defined.
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
with self.cached_session():
self.assertAllClose(selected_indices, [0, 2, 4])
self.assertEqual(num_valid.eval(), 3)
def testSelectFromContinuousOverLapV1(self):
with ops.Graph().as_default():
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
max_output_size_np = 3
iou_threshold_np = 0.5
score_threshold_np = 0.1
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
iou_threshold = constant_op.constant(iou_threshold_np)
score_threshold = constant_op.constant(score_threshold_np)
selected_indices, num_valid = image_ops.non_max_suppression_padded(
boxes,
scores,
max_output_size,
iou_threshold,
score_threshold)
# The output shape of the padded operation must be fully defined.
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
with self.cached_session():
self.assertAllClose(selected_indices, [0, 2, 4])
self.assertEqual(num_valid.eval(), 3)
@parameterized.named_parameters([("_RunEagerly", True), ("_RunGraph", False)])
@test_util.xla_allow_fallback(
"non_max_suppression with dynamic output shape unsupported.")
def testSelectFromContinuousOverLapV2(self, run_func_eagerly):
if not context.executing_eagerly() and run_func_eagerly:
# Skip running tf.function eagerly in V1 mode.
self.skipTest("Skip test that runs tf.function eagerly in V1 mode.")
else:
@def_function.function
def func(boxes, scores, max_output_size, iou_threshold, score_threshold):
boxes = constant_op.constant(boxes)
scores = constant_op.constant(scores)
max_output_size = constant_op.constant(max_output_size)
iou_threshold = constant_op.constant(iou_threshold)
score_threshold = constant_op.constant(score_threshold)
y, nv = image_ops.non_max_suppression_padded(
boxes, scores, max_output_size, iou_threshold, score_threshold)
# The output shape of the padded operation must be fully defined.
self.assertEqual(y.shape.is_fully_defined(), False)
return y, nv
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
max_output_size_np = 3
iou_threshold_np = 0.5
score_threshold_np = 0.1
selected_indices, num_valid = func(boxes_np, scores_np,
max_output_size_np, iou_threshold_np,
score_threshold_np)
with self.cached_session():
with test_util.run_functions_eagerly(run_func_eagerly):
self.assertAllClose(selected_indices, [0, 2, 4])
self.assertEqual(self.evaluate(num_valid), 3)
class NonMaxSuppressionWithOverlapsTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def testSelectOneFromThree(self):
overlaps_np = [
[1.0, 0.7, 0.2],
@ -5068,28 +5165,31 @@ class NonMaxSuppressionWithOverlapsTest(test_util.TensorFlowTestCase):
class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase):
"""Tests utility function used by ssim() and psnr()."""
@test_util.run_deprecated_v1
def testWrongDims(self):
img = array_ops.placeholder(dtype=dtypes.float32)
img_np = np.array((2, 2))
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
img = array_ops.placeholder(dtype=dtypes.float32)
img_np = np.array((2, 2))
with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img: img_np})
with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img: img_np})
@test_util.run_deprecated_v1
def testShapeMismatch(self):
img1 = array_ops.placeholder(dtype=dtypes.float32)
img2 = array_ops.placeholder(dtype=dtypes.float32)
# Shape function requires placeholders and a graph.
with ops.Graph().as_default():
img1 = array_ops.placeholder(dtype=dtypes.float32)
img2 = array_ops.placeholder(dtype=dtypes.float32)
img1_np = np.array([1, 2, 2, 1])
img2_np = np.array([1, 3, 3, 1])
img1_np = np.array([1, 2, 2, 1])
img2_np = np.array([1, 3, 3, 1])
with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img1, img2)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img1: img1_np, img2: img2_np})
with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(
img1, img2)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img1: img1_np, img2: img2_np})
class PSNRTest(test_util.TensorFlowTestCase):