Merge pull request #40693 from yongtang:40653-tf.math.segment_mean-error-message
PiperOrigin-RevId: 324550275 Change-Id: I0e6350d251bdff3603cc9f0ff3957edfce6e0a0f
This commit is contained in:
commit
1e9b9b1568
@ -22,6 +22,8 @@ namespace internal {
|
|||||||
void SegmentReductionValidationHelper(OpKernelContext* context,
|
void SegmentReductionValidationHelper(OpKernelContext* context,
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
const Tensor& segment_ids) {
|
const Tensor& segment_ids) {
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
|
||||||
|
errors::InvalidArgument("input must be at least rank 1"));
|
||||||
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
|
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
|
||||||
errors::InvalidArgument("segment_ids should be a vector."));
|
errors::InvalidArgument("segment_ids should be a vector."));
|
||||||
const int64 num_indices = segment_ids.NumElements();
|
const int64 num_indices = segment_ids.NumElements();
|
||||||
|
@ -25,6 +25,7 @@ import numpy as np
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gradient_checker
|
from tensorflow.python.ops import gradient_checker
|
||||||
@ -255,6 +256,17 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
|||||||
delta=1)
|
delta=1)
|
||||||
self.assertAllClose(jacob_t, jacob_n)
|
self.assertAllClose(jacob_t, jacob_n)
|
||||||
|
|
||||||
|
def testDataInvalid(self):
|
||||||
|
# Test case for GitHub issue 40653.
|
||||||
|
for use_gpu in [True, False]:
|
||||||
|
with self.cached_session(use_gpu=use_gpu):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
(ValueError, errors_impl.InvalidArgumentError),
|
||||||
|
"must be at least rank 1"):
|
||||||
|
s = math_ops.segment_mean(
|
||||||
|
data=np.uint16(10), segment_ids=np.array([]).astype("int64"))
|
||||||
|
self.evaluate(s)
|
||||||
|
|
||||||
|
|
||||||
class UnsortedSegmentTest(SegmentReductionHelper):
|
class UnsortedSegmentTest(SegmentReductionHelper):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user