Eager execution coverage for image_ops_test.py. Removed run_deprecated_v1
decorators.
Part 14 (class FormatTest, NonMaxSuppression*Test) PiperOrigin-RevId: 339106641 Change-Id: I504ff77c3713bc294124e3e5b449cd5bc5807786
This commit is contained in:
parent
2eda043cf4
commit
237c3268e9
tensorflow
@ -43,10 +43,14 @@ static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
|
||||
const Tensor& scores) {
|
||||
// The shape of 'scores' is [num_boxes]
|
||||
OP_REQUIRES(context, scores.dims() == 1,
|
||||
errors::InvalidArgument("scores must be 1-D",
|
||||
scores.shape().DebugString()));
|
||||
OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
|
||||
errors::InvalidArgument("scores has incompatible shape"));
|
||||
errors::InvalidArgument(
|
||||
"scores must be 1-D", scores.shape().DebugString(),
|
||||
" (Shape must be rank 1 but is rank ", scores.dims(), ")"));
|
||||
OP_REQUIRES(
|
||||
context, scores.dim_size(0) == num_boxes,
|
||||
errors::InvalidArgument("scores has incompatible shape (Dimensions must "
|
||||
"be equal, but are ",
|
||||
num_boxes, " and ", scores.dim_size(0), ")"));
|
||||
}
|
||||
|
||||
static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
|
||||
@ -67,11 +71,14 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
|
||||
const Tensor& boxes, int* num_boxes) {
|
||||
// The shape of 'boxes' is [num_boxes, 4]
|
||||
OP_REQUIRES(context, boxes.dims() == 2,
|
||||
errors::InvalidArgument("boxes must be 2-D",
|
||||
boxes.shape().DebugString()));
|
||||
errors::InvalidArgument(
|
||||
"boxes must be 2-D", boxes.shape().DebugString(),
|
||||
" (Shape must be rank 2 but is rank ", boxes.dims(), ")"));
|
||||
*num_boxes = boxes.dim_size(0);
|
||||
OP_REQUIRES(context, boxes.dim_size(1) == 4,
|
||||
errors::InvalidArgument("boxes must have 4 columns"));
|
||||
errors::InvalidArgument("boxes must have 4 columns (Dimension "
|
||||
"must be 4 but is ",
|
||||
boxes.dim_size(1), ")"));
|
||||
}
|
||||
|
||||
static inline void CheckCombinedNMSScoreSizes(OpKernelContext* context,
|
||||
@ -670,12 +677,16 @@ class NonMaxSuppressionV3Op : public OpKernel {
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsScalar(max_output_size.shape()),
|
||||
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
|
||||
max_output_size.shape().DebugString()));
|
||||
max_output_size.shape().DebugString(),
|
||||
" (Shape must be rank 0 but is ", "rank ",
|
||||
max_output_size.dims(), ")"));
|
||||
// iou_threshold: scalar
|
||||
const Tensor& iou_threshold = context->input(3);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
|
||||
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
|
||||
iou_threshold.shape().DebugString()));
|
||||
iou_threshold.shape().DebugString(),
|
||||
" (Shape must be rank 0 but is rank ",
|
||||
iou_threshold.dims(), ")"));
|
||||
const T iou_threshold_val = iou_threshold.scalar<T>()();
|
||||
OP_REQUIRES(context,
|
||||
iou_threshold_val >= static_cast<T>(0.0) &&
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
@ -23,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
#include "tensorflow/core/util/gpu_launch_config.h"
|
||||
#include "tensorflow/stream_executor/stream_executor.h"
|
||||
@ -550,30 +551,44 @@ Status CheckValidInputs(const Tensor& boxes, const Tensor& scores,
|
||||
const Tensor& iou_threshold) {
|
||||
if (!TensorShapeUtils::IsScalar(max_output_size.shape())) {
|
||||
return errors::InvalidArgument("max_output_size must be 0-D, got shape ",
|
||||
max_output_size.shape().DebugString());
|
||||
max_output_size.shape().DebugString(),
|
||||
" (Shape must be rank 0 but is ", "rank ",
|
||||
max_output_size.dims(), ")");
|
||||
}
|
||||
if (!TensorShapeUtils::IsScalar(iou_threshold.shape())) {
|
||||
return errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
|
||||
iou_threshold.shape().DebugString());
|
||||
iou_threshold.shape().DebugString(),
|
||||
" (Shape must be rank 0 but is rank ",
|
||||
iou_threshold.dims(), ")");
|
||||
}
|
||||
const float iou_threshold_val = iou_threshold.scalar<float>()();
|
||||
if (iou_threshold_val < 0 || iou_threshold_val > 1) {
|
||||
return errors::InvalidArgument("iou_threshold must be in [0, 1]");
|
||||
}
|
||||
if (boxes.dims() != 2) {
|
||||
return errors::InvalidArgument("boxes must be a rank 2 tensor!");
|
||||
return errors::InvalidArgument(
|
||||
"boxes must be a rank 2 tensor! (Shape must "
|
||||
"be rank 2 but is rank ",
|
||||
boxes.dims(), ")");
|
||||
}
|
||||
int num_boxes = boxes.dim_size(0);
|
||||
if (boxes.dim_size(1) != 4) {
|
||||
return errors::InvalidArgument("boxes must be Nx4");
|
||||
return errors::InvalidArgument(
|
||||
"boxes must be Nx4 (Dimension must be 4 but"
|
||||
" is ",
|
||||
boxes.dim_size(1), ")");
|
||||
}
|
||||
if (scores.dims() != 1) {
|
||||
return errors::InvalidArgument("scores must be a vector!");
|
||||
return errors::InvalidArgument(
|
||||
"scores must be a vector! (Shape must be "
|
||||
"rank 1 but is rank ",
|
||||
scores.dims(), ")");
|
||||
}
|
||||
if (scores.dim_size(0) != num_boxes) {
|
||||
return errors::InvalidArgument(
|
||||
"scores has incompatible shape"); // message must be exactly this
|
||||
// otherwise tests fail!
|
||||
"scores has incompatible shape " // message must be exactly this
|
||||
"(Dimensions must be equal, but are ", // otherwise tests fail!
|
||||
num_boxes, " and ", scores.dim_size(0), ")");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -4784,7 +4785,6 @@ class TotalVariationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class FormatTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFormats(self):
|
||||
prefix = "tensorflow/core/lib"
|
||||
paths = ("png/testdata/lena_gray.png", "jpeg/testdata/jpeg_merge_test1.jpg",
|
||||
@ -4796,10 +4796,10 @@ class FormatTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
with self.cached_session():
|
||||
for path in paths:
|
||||
contents = io_ops.read_file(os.path.join(prefix, path)).eval()
|
||||
contents = self.evaluate(io_ops.read_file(os.path.join(prefix, path)))
|
||||
images = {}
|
||||
for name, decode in decoders.items():
|
||||
image = decode(contents).eval()
|
||||
image = self.evaluate(decode(contents))
|
||||
self.assertEqual(image.ndim, 3)
|
||||
for prev_name, prev in images.items():
|
||||
print("path %s, names %s %s, shapes %s %s" %
|
||||
@ -4817,7 +4817,6 @@ class FormatTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def NonMaxSuppressionTest(self):
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
|
||||
@ -4833,50 +4832,56 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
||||
boxes, scores, max_output_size, iou_threshold)
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidShape(self):
|
||||
|
||||
def nms_func(box, score, iou_thres, score_thres):
|
||||
return image_ops.non_max_suppression(box, score, iou_thres, score_thres)
|
||||
|
||||
iou_thres = 3
|
||||
score_thres = 0.5
|
||||
|
||||
# The boxes should be 2D of shape [num_boxes, 4].
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Shape must be rank 2 but is rank 1"):
|
||||
boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0])
|
||||
scores = constant_op.constant([0.9])
|
||||
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
|
||||
nms_func(boxes, scores, iou_thres, score_thres)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Dimension must be 4 but is 3"):
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Dimension must be 4 but is 3"):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
|
||||
nms_func(boxes, scores, iou_thres, score_thres)
|
||||
|
||||
# The boxes is of shape [num_boxes, 4], and the scores is
|
||||
# of shape [num_boxes]. So an error will be thrown.
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Dimensions must be equal, but are 1 and 2"):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([0.9, 0.75])
|
||||
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
|
||||
nms_func(boxes, scores, iou_thres, score_thres)
|
||||
|
||||
# The scores should be 1D of shape [num_boxes].
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Shape must be rank 1 but is rank 2"):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([[0.9]])
|
||||
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
|
||||
nms_func(boxes, scores, iou_thres, score_thres)
|
||||
|
||||
# The max_output_size should be a scalar (0-D).
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Shape must be rank 0 but is rank 1"):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
image_ops.non_max_suppression(boxes, scores, [3], 0.5)
|
||||
nms_func(boxes, scores, [iou_thres], score_thres)
|
||||
|
||||
# The iou_threshold should be a scalar (0-D).
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
|
||||
"Shape must be rank 0 but is rank 2"):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
|
||||
nms_func(boxes, scores, iou_thres, [[score_thres]])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.xla_allow_fallback(
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testDataTypes(self):
|
||||
@ -4896,7 +4901,8 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np, dtype=dtype)
|
||||
selected_indices = gen_image_ops.non_max_suppression_v2(
|
||||
boxes, scores, max_output_size, iou_threshold).eval()
|
||||
boxes, scores, max_output_size, iou_threshold)
|
||||
selected_indices = self.evaluate(selected_indices)
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
# gen_image_ops.non_max_suppression_v3
|
||||
for dtype in [np.float16, np.float32]:
|
||||
@ -4941,7 +4947,6 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class NonMaxSuppressionWithScoresTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.xla_allow_fallback(
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testSelectFromThreeClustersWithSoftNMS(self):
|
||||
@ -4974,75 +4979,167 @@ class NonMaxSuppressionWithScoresTest(test_util.TensorFlowTestCase):
|
||||
rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
|
||||
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla(
|
||||
"b/141236442: "
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testSelectFromThreeClusters(self):
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
|
||||
max_output_size_np = 5
|
||||
iou_threshold_np = 0.5
|
||||
boxes = constant_op.constant(boxes_np)
|
||||
scores = constant_op.constant(scores_np)
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np)
|
||||
selected_indices_padded, num_valid_padded = \
|
||||
image_ops.non_max_suppression_padded(
|
||||
def testSelectFromThreeClustersV1(self):
|
||||
with ops.Graph().as_default():
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
|
||||
max_output_size_np = 5
|
||||
iou_threshold_np = 0.5
|
||||
boxes = constant_op.constant(boxes_np)
|
||||
scores = constant_op.constant(scores_np)
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np)
|
||||
selected_indices_padded, num_valid_padded = \
|
||||
image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
pad_to_max_output_size=True)
|
||||
selected_indices, num_valid = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
pad_to_max_output_size=False)
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
|
||||
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
|
||||
with self.cached_session():
|
||||
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
|
||||
self.assertEqual(num_valid_padded.eval(), 3)
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
self.assertEqual(num_valid.eval(), 3)
|
||||
|
||||
@parameterized.named_parameters([("_RunEagerly", True), ("_RunGraph", False)])
|
||||
@test_util.disable_xla(
|
||||
"b/141236442: "
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testSelectFromThreeClustersV2(self, run_func_eagerly):
|
||||
if not context.executing_eagerly() and run_func_eagerly:
|
||||
# Skip running tf.function eagerly in V1 mode.
|
||||
self.skipTest("Skip test that runs tf.function eagerly in V1 mode.")
|
||||
else:
|
||||
|
||||
@def_function.function
|
||||
def func(boxes, scores, max_output_size, iou_threshold):
|
||||
boxes = constant_op.constant(boxes_np)
|
||||
scores = constant_op.constant(scores_np)
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np)
|
||||
|
||||
yp, nvp = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
pad_to_max_output_size=True)
|
||||
selected_indices, num_valid = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
pad_to_max_output_size=False)
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
|
||||
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
|
||||
with self.cached_session():
|
||||
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
|
||||
self.assertEqual(num_valid_padded.eval(), 3)
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
self.assertEqual(num_valid.eval(), 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
y, n = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
pad_to_max_output_size=False)
|
||||
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(yp.shape.is_fully_defined(), True)
|
||||
self.assertEqual(y.shape.is_fully_defined(), False)
|
||||
|
||||
return yp, nvp, y, n
|
||||
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
|
||||
max_output_size_np = 5
|
||||
iou_threshold_np = 0.5
|
||||
|
||||
selected_indices_padded, num_valid_padded, selected_indices, num_valid = \
|
||||
func(boxes_np, scores_np, max_output_size_np, iou_threshold_np)
|
||||
|
||||
with self.cached_session():
|
||||
with test_util.run_functions_eagerly(run_func_eagerly):
|
||||
self.assertAllClose(selected_indices_padded, [3, 0, 5, 0, 0])
|
||||
self.assertEqual(self.evaluate(num_valid_padded), 3)
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
self.assertEqual(self.evaluate(num_valid), 3)
|
||||
|
||||
@test_util.xla_allow_fallback(
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testSelectFromContinuousOverLap(self):
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
|
||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
|
||||
max_output_size_np = 3
|
||||
iou_threshold_np = 0.5
|
||||
score_threshold_np = 0.1
|
||||
boxes = constant_op.constant(boxes_np)
|
||||
scores = constant_op.constant(scores_np)
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np)
|
||||
score_threshold = constant_op.constant(score_threshold_np)
|
||||
selected_indices, num_valid = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
score_threshold)
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
|
||||
with self.cached_session():
|
||||
self.assertAllClose(selected_indices, [0, 2, 4])
|
||||
self.assertEqual(num_valid.eval(), 3)
|
||||
def testSelectFromContinuousOverLapV1(self):
|
||||
with ops.Graph().as_default():
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
|
||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
|
||||
max_output_size_np = 3
|
||||
iou_threshold_np = 0.5
|
||||
score_threshold_np = 0.1
|
||||
boxes = constant_op.constant(boxes_np)
|
||||
scores = constant_op.constant(scores_np)
|
||||
max_output_size = constant_op.constant(max_output_size_np)
|
||||
iou_threshold = constant_op.constant(iou_threshold_np)
|
||||
score_threshold = constant_op.constant(score_threshold_np)
|
||||
selected_indices, num_valid = image_ops.non_max_suppression_padded(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold,
|
||||
score_threshold)
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
|
||||
with self.cached_session():
|
||||
self.assertAllClose(selected_indices, [0, 2, 4])
|
||||
self.assertEqual(num_valid.eval(), 3)
|
||||
|
||||
@parameterized.named_parameters([("_RunEagerly", True), ("_RunGraph", False)])
|
||||
@test_util.xla_allow_fallback(
|
||||
"non_max_suppression with dynamic output shape unsupported.")
|
||||
def testSelectFromContinuousOverLapV2(self, run_func_eagerly):
|
||||
if not context.executing_eagerly() and run_func_eagerly:
|
||||
# Skip running tf.function eagerly in V1 mode.
|
||||
self.skipTest("Skip test that runs tf.function eagerly in V1 mode.")
|
||||
else:
|
||||
|
||||
@def_function.function
|
||||
def func(boxes, scores, max_output_size, iou_threshold, score_threshold):
|
||||
boxes = constant_op.constant(boxes)
|
||||
scores = constant_op.constant(scores)
|
||||
max_output_size = constant_op.constant(max_output_size)
|
||||
iou_threshold = constant_op.constant(iou_threshold)
|
||||
score_threshold = constant_op.constant(score_threshold)
|
||||
|
||||
y, nv = image_ops.non_max_suppression_padded(
|
||||
boxes, scores, max_output_size, iou_threshold, score_threshold)
|
||||
|
||||
# The output shape of the padded operation must be fully defined.
|
||||
self.assertEqual(y.shape.is_fully_defined(), False)
|
||||
|
||||
return y, nv
|
||||
|
||||
boxes_np = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
|
||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||
scores_np = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3]
|
||||
max_output_size_np = 3
|
||||
iou_threshold_np = 0.5
|
||||
score_threshold_np = 0.1
|
||||
selected_indices, num_valid = func(boxes_np, scores_np,
|
||||
max_output_size_np, iou_threshold_np,
|
||||
score_threshold_np)
|
||||
with self.cached_session():
|
||||
with test_util.run_functions_eagerly(run_func_eagerly):
|
||||
self.assertAllClose(selected_indices, [0, 2, 4])
|
||||
self.assertEqual(self.evaluate(num_valid), 3)
|
||||
|
||||
|
||||
class NonMaxSuppressionWithOverlapsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSelectOneFromThree(self):
|
||||
overlaps_np = [
|
||||
[1.0, 0.7, 0.2],
|
||||
@ -5068,28 +5165,31 @@ class NonMaxSuppressionWithOverlapsTest(test_util.TensorFlowTestCase):
|
||||
class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase):
|
||||
"""Tests utility function used by ssim() and psnr()."""
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWrongDims(self):
|
||||
img = array_ops.placeholder(dtype=dtypes.float32)
|
||||
img_np = np.array((2, 2))
|
||||
# Shape function requires placeholders and a graph.
|
||||
with ops.Graph().as_default():
|
||||
img = array_ops.placeholder(dtype=dtypes.float32)
|
||||
img_np = np.array((2, 2))
|
||||
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(checks, {img: img_np})
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(checks, {img: img_np})
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeMismatch(self):
|
||||
img1 = array_ops.placeholder(dtype=dtypes.float32)
|
||||
img2 = array_ops.placeholder(dtype=dtypes.float32)
|
||||
# Shape function requires placeholders and a graph.
|
||||
with ops.Graph().as_default():
|
||||
img1 = array_ops.placeholder(dtype=dtypes.float32)
|
||||
img2 = array_ops.placeholder(dtype=dtypes.float32)
|
||||
|
||||
img1_np = np.array([1, 2, 2, 1])
|
||||
img2_np = np.array([1, 3, 3, 1])
|
||||
img1_np = np.array([1, 2, 2, 1])
|
||||
img2_np = np.array([1, 3, 3, 1])
|
||||
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img1, img2)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(checks, {img1: img1_np, img2: img2_np})
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
_, _, checks = image_ops_impl._verify_compatible_image_shapes(
|
||||
img1, img2)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(checks, {img1: img1_np, img2: img2_np})
|
||||
|
||||
|
||||
class PSNRTest(test_util.TensorFlowTestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user