From cc7e68b611392174c95e13fa8efa9d7bd21832df Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 22 Jun 2020 23:14:18 +0000 Subject: [PATCH] Update to have a better error message for tf.math.segment_[*] This PR tries to address the issue in 40653 where the error message of `tf.math.segment_[*]` does not match the error. This PR fixes 40653. Signed-off-by: Yong Tang --- .../core/kernels/segment_reduction_ops_impl_1.cc | 2 ++ .../kernel_tests/segment_reduction_ops_test.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc index ae71ac31f2c..f71a8dac462 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc @@ -22,6 +22,8 @@ namespace internal { void SegmentReductionValidationHelper(OpKernelContext* context, const Tensor& input, const Tensor& segment_ids) { + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()), + errors::InvalidArgument("input must be at least rank 1")); OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), errors::InvalidArgument("segment_ids should be a vector.")); const int64 num_indices = segment_ids.NumElements(); diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 9c0e0e38b6a..03d31a59b47 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker @@ -255,6 +256,17 @@ class SegmentReductionOpTest(SegmentReductionHelper): delta=1) self.assertAllClose(jacob_t, jacob_n) + def testDataInvalid(self): + # Test case for GitHub issue 40653. + for use_gpu in [True, False]: + with self.cached_session(use_gpu=use_gpu): + with self.assertRaisesRegex( + (ValueError, errors_impl.InvalidArgumentError), + "must be at least rank 1"): + s = math_ops.segment_mean( + data=np.uint16(10), segment_ids=np.array([]).astype('int64')) + self.evaluate(s) + class UnsortedSegmentTest(SegmentReductionHelper):