Merge pull request #40693 from yongtang:40653-tf.math.segment_mean-error-message

PiperOrigin-RevId: 324550275
Change-Id: I0e6350d251bdff3603cc9f0ff3957edfce6e0a0f
This commit is contained in:
TensorFlower Gardener 2020-08-02 23:53:04 -07:00
commit 1e9b9b1568
2 changed files with 14 additions and 0 deletions

View File

@ -22,6 +22,8 @@ namespace internal {
void SegmentReductionValidationHelper(OpKernelContext* context,
const Tensor& input,
const Tensor& segment_ids) {
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
errors::InvalidArgument("input must be at least rank 1"));
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
errors::InvalidArgument("segment_ids should be a vector."));
const int64 num_indices = segment_ids.NumElements();

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
@ -255,6 +256,17 @@ class SegmentReductionOpTest(SegmentReductionHelper):
delta=1)
self.assertAllClose(jacob_t, jacob_n)
def testDataInvalid(self):
# Test case for GitHub issue 40653.
for use_gpu in [True, False]:
with self.cached_session(use_gpu=use_gpu):
with self.assertRaisesRegex(
(ValueError, errors_impl.InvalidArgumentError),
"must be at least rank 1"):
s = math_ops.segment_mean(
data=np.uint16(10), segment_ids=np.array([]).astype("int64"))
self.evaluate(s)
class UnsortedSegmentTest(SegmentReductionHelper):