Merge pull request #40693 from yongtang:40653-tf.math.segment_mean-error-message
PiperOrigin-RevId: 324550275 Change-Id: I0e6350d251bdff3603cc9f0ff3957edfce6e0a0f
This commit is contained in:
commit
1e9b9b1568
@ -22,6 +22,8 @@ namespace internal {
|
||||
void SegmentReductionValidationHelper(OpKernelContext* context,
|
||||
const Tensor& input,
|
||||
const Tensor& segment_ids) {
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
|
||||
errors::InvalidArgument("input must be at least rank 1"));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
|
||||
errors::InvalidArgument("segment_ids should be a vector."));
|
||||
const int64 num_indices = segment_ids.NumElements();
|
||||
|
@ -25,6 +25,7 @@ import numpy as np
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -255,6 +256,17 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
delta=1)
|
||||
self.assertAllClose(jacob_t, jacob_n)
|
||||
|
||||
def testDataInvalid(self):
|
||||
# Test case for GitHub issue 40653.
|
||||
for use_gpu in [True, False]:
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
with self.assertRaisesRegex(
|
||||
(ValueError, errors_impl.InvalidArgumentError),
|
||||
"must be at least rank 1"):
|
||||
s = math_ops.segment_mean(
|
||||
data=np.uint16(10), segment_ids=np.array([]).astype("int64"))
|
||||
self.evaluate(s)
|
||||
|
||||
|
||||
class UnsortedSegmentTest(SegmentReductionHelper):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user