Merge pull request #46741 from yongtang:46700-tf.math.reduce_prod-keepdims
PiperOrigin-RevId: 354991067 Change-Id: Ibf9c353042b49b5ffc8113aed99fb9f3cbf6eea2
This commit is contained in:
commit
808fc11053
@ -116,6 +116,24 @@ class ReductionUnknownShape(test.TestCase):
|
||||
self.assertEqual(y.shape, ())
|
||||
|
||||
|
||||
class ReductionInvalidKeepdims(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
# Test case for GitHub issue 46700.
|
||||
for dtype, reductions in [
|
||||
(dtypes.float32, (math_ops.reduce_sum, math_ops.reduce_mean,
|
||||
math_ops.reduce_prod, math_ops.reduce_max,
|
||||
math_ops.reduce_min, math_ops.reduce_euclidean_norm)),
|
||||
(dtypes.bool, (math_ops.reduce_all, math_ops.reduce_any))
|
||||
]:
|
||||
for reduction in reductions:
|
||||
with self.assertRaisesRegex(ValueError, "The truth value"):
|
||||
x = True if dtype == dtypes.bool else 1
|
||||
y = reduction(
|
||||
input_tensor=x, keepdims=np.array([63600, 1], dtype=np.float16))
|
||||
self.evaluate(y)
|
||||
|
||||
|
||||
class BaseReductionTest(test.TestCase):
|
||||
|
||||
def _tf_reduce(self, x, reduction_axes, keepdims):
|
||||
|
@ -2016,7 +2016,7 @@ def reduce_sum_with_dims(input_tensor,
|
||||
keepdims=False,
|
||||
name=None,
|
||||
dims=None):
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops._sum(input_tensor, dims, keepdims, name=name))
|
||||
@ -2059,6 +2059,7 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
|
||||
Returns:
|
||||
The reduced tensor, of the same dtype as the input_tensor.
|
||||
"""
|
||||
keepdims = bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops.euclidean_norm(
|
||||
@ -2331,7 +2332,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
|
||||
|
||||
@end_compatibility
|
||||
"""
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops.mean(
|
||||
@ -2491,7 +2492,7 @@ def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
|
||||
Equivalent to np.prod
|
||||
@end_compatibility
|
||||
"""
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops.prod(
|
||||
@ -2678,7 +2679,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
|
||||
Equivalent to np.min
|
||||
@end_compatibility
|
||||
"""
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops._min(
|
||||
@ -2805,7 +2806,7 @@ def reduce_max_with_dims(input_tensor,
|
||||
keepdims=False,
|
||||
name=None,
|
||||
dims=None):
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops._max(input_tensor, dims, keepdims, name=name))
|
||||
@ -2909,7 +2910,7 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
|
||||
Equivalent to np.all
|
||||
@end_compatibility
|
||||
"""
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops._all(
|
||||
@ -3015,7 +3016,7 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
|
||||
Equivalent to np.any
|
||||
@end_compatibility
|
||||
"""
|
||||
keepdims = False if keepdims is None else keepdims
|
||||
keepdims = False if keepdims is None else bool(keepdims)
|
||||
return _may_reduce_to_scalar(
|
||||
keepdims, axis,
|
||||
gen_math_ops._any(
|
||||
|
Loading…
Reference in New Issue
Block a user