diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 13215ffabf3..8b6ed9f041b 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -81,7 +81,7 @@ class ExpectationImportanceSampleTest(test.TestCase): # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x). # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0). def indicator(x): - x1_times_x2 = math_ops.reduce_prod(x, reduction_indices=[-1]) + x1_times_x2 = math_ops.reduce_prod(x, axis=[-1]) return 0.5 * (math_ops.sign(x1_times_x2) + 1.0) prob = mc.expectation_importance_sampler( diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 18d40fc1dff..e83a5485119 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -353,12 +353,12 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, def _sample_mean(values): """Mean over sample indices. In this module this is always [0].""" - return math_ops.reduce_mean(values, reduction_indices=[0]) + return math_ops.reduce_mean(values, axis=[0]) def _sample_max(values): """Max over sample indices. In this module this is always [0].""" - return math_ops.reduce_max(values, reduction_indices=[0]) + return math_ops.reduce_max(values, axis=[0]) def _get_samples(dist, z, n, seed): diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 7a99dccdd10..220e981618b 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -119,8 +119,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): labels = array_ops.expand_dims(labels, 1) # Labels are indices of classes, convert them to one hot encodings. target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) - labels = math_ops.reduce_sum( - input_tensor=target_one_hot, reduction_indices=[1]) + labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) labels = math_ops.to_float(labels) # Calculate softmax probabilities for each class. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py index 29eeaf43c51..ab3c07172a6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py @@ -82,7 +82,7 @@ class NormalTest(test.TestCase): x = constant_op.constant( [[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], [2.5, -2.5, -4.0, 0.0, 1.0, -2.0]], dtype=dtypes.float32) - s = math_ops.reduce_sum(x, reduction_indices=[1]) + s = math_ops.reduce_sum(x, axis=[1]) x = array_ops.transpose(x) # Reshape to shape (6, 2) n = constant_op.constant([6] * 2) prior = distributions.Normal(loc=mu0, scale=sigma0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index a60056c444a..cdee30bbc42 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -147,14 +147,13 @@ class WishartCholeskyTest(test.TestCase): x = chol_w.sample(10000, seed=42) self.assertAllEqual((10000, 3, 3), x.get_shape()) - moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval() + moment1_estimate = math_ops.reduce_mean(x, axis=[0]).eval() self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05) # The Variance estimate uses the squares rather than outer-products # because Wishart.Variance is the diagonal of the Wishart covariance # matrix. - variance_estimate = (math_ops.reduce_mean( - math_ops.square(x), reduction_indices=[0]) - + variance_estimate = (math_ops.reduce_mean(math_ops.square(x), axis=[0]) - math_ops.square(moment1_estimate)).eval() self.assertAllClose( chol_w.variance().eval(), variance_estimate, rtol=0.05) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 15c241d5d7a..74765f19e58 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -168,7 +168,7 @@ class SoftmaxCentered(bijector.Bijector): # log_normalization = 1 + reduce_sum(exp(logits)) # -log_normalization + reduce_sum(logits - log_normalization) log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + math_ops.reduce_logsumexp(x, axis=-1, keepdims=True)) return array_ops.squeeze( (-log_normalization + math_ops.reduce_sum( x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 4454abfb966..8c35dddb5a5 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -87,8 +87,8 @@ class TFETest(test_util.TensorFlowTestCase): x += 1. # Without a device context, heuristics are used to place ops. # In this case, ops.reduce_mean runs on the GPU. - reduction_indices = range(x.shape.ndims) - m = math_ops.reduce_mean(x, reduction_indices) + axis = range(x.shape.ndims) + m = math_ops.reduce_mean(x, axis) # m is on GPU, bring it back to CPU and compare. self.assertEqual(3.5, m.cpu().numpy()) diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index f42112206d0..3671633c8d7 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ b/tensorflow/contrib/layers/python/layers/encoders.py @@ -84,8 +84,7 @@ def bow_encoder(ids, if isinstance(ids, sparse_tensor.SparseTensor): raise TypeError('ids are expected to be dense Tensor, got: %s', ids) return math_ops.reduce_mean( - embedding_ops.embedding_lookup(embeddings, ids), - reduction_indices=1) + embedding_ops.embedding_lookup(embeddings, ids), axis=1) def embed_sequence(ids, diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 222404b19db..00d819ed0e9 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1015,8 +1015,7 @@ class _OneHotColumn( dense_id_tensor, depth=self.length, on_value=1.0, off_value=0.0) # Reduce to get a multi-hot per example. - return math_ops.reduce_sum( - one_hot_id_tensor, reduction_indices=[output_rank - 1]) + return math_ops.reduce_sum(one_hot_id_tensor, axis=[output_rank - 1]) @property def _variable_shape(self): diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 8ead6336a08..0a4d2c6d4cb 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -3811,7 +3811,7 @@ class UnitNormTests(test.TestCase): image = random_ops.random_uniform((height, width, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) shape = [height, width, 3] del shape[dim] @@ -3847,7 +3847,7 @@ class UnitNormTests(test.TestCase): image = array_ops.placeholder(dtypes.float32, (None, None, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) with self.cached_session(): actual = norms.eval({image: placeholder_value}) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 1d8a59281a4..28c4964527b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -668,7 +668,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): sequences = centers + noise inputs = array_ops.expand_dims(sequences, 2) - labels = math_ops.reduce_mean(sequences, reduction_indices=[1]) + labels = math_ops.reduce_mean(sequences, axis=[1]) return {'inputs': inputs}, labels return input_fn @@ -722,8 +722,8 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2) labels = math_ops.to_int32( array_ops.squeeze( - math_ops.reduce_sum( - inputs, reduction_indices=[1]) > (sequence_length / 2.0))) + math_ops.reduce_sum(inputs, axis=[1]) > ( + sequence_length / 2.0))) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index d8ac4163b21..709a042bbce 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -59,9 +59,8 @@ def _scale_losses(losses, weights): """ # First, compute the sum of the losses over all elements: start_index = max(0, weights.get_shape().ndims) - reduction_indices = list(range(start_index, losses.get_shape().ndims)) - reduced_losses = math_ops.reduce_sum( - losses, reduction_indices=reduction_indices) + axis = list(range(start_index, losses.get_shape().ndims)) + reduced_losses = math_ops.reduce_sum(losses, axis=axis) reduced_losses = math_ops.multiply(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) @@ -158,10 +157,9 @@ def _num_present(losses, weights, per_batch=False): # First, count the number of nonzero weights: if weights.get_shape().ndims >= 1: - reduction_indices = list(range(1, weights.get_shape().ndims)) + axis = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weights, 0)), - reduction_indices=reduction_indices) + math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis) # Next, determine the number of elements that weights would broadcast to: broadcast_dims = array_ops.slice( @@ -577,16 +575,16 @@ def mean_pairwise_squared_error(predictions, if weights.get_shape().ndims is None: raise ValueError("weights.get_shape().ndims cannot be None") - reduction_indices = list(range(1, diffs.get_shape().ndims)) + axis = list(range(1, diffs.get_shape().ndims)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), reduction_indices=reduction_indices) + math_ops.square(diffs), axis=axis) num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * math_ops.div_no_nan( sum_squares_diff_per_batch, num_present_per_batch, name="value") - sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) + sum_diff = math_ops.reduce_sum(diffs, axis=axis) term2 = 2.0 * math_ops.div_no_nan( math_ops.square(sum_diff), math_ops.square(num_present_per_batch), @@ -645,7 +643,7 @@ def cosine_distance(predictions, radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ axis, ]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 09fe65b73f8..7b432f8bd20 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -3416,7 +3416,7 @@ def streaming_mean_cosine_distance(predictions, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) radial_diffs = math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ dim, ], keepdims=True) mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 6f659347fba..8619708cdae 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -138,7 +138,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -147,7 +147,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: @@ -263,7 +263,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -272,7 +272,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 237b7f304e0..8b85548e5c9 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -74,7 +74,7 @@ class BackpropTest(test.TestCase): tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1) tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2) tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3) - tf_g4 = math_ops.reduce_sum(tf_var * 2.0, reduction_indices=(0, 1)) + tf_g4 = math_ops.reduce_sum(tf_var * 2.0, axis=(0, 1)) tf_y = tf_g1 * tf_g2 * tf_g3 * tf_g4 tf_grad = gradients.gradients(tf_y, [tf_var])[0] diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py index b8225b81a52..de80df1879d 100644 --- a/tensorflow/python/grappler/cost_analyzer_test.py +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -96,8 +96,8 @@ class CostAnalysisTest(test.TestCase): b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1)) y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc) - cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum( - label * math_ops.log(y_conv), reduction_indices=[1])) + cross_entropy = math_ops.reduce_mean( + -math_ops.reduce_sum(label * math_ops.log(y_conv), axis=[1])) _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy) mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 08d50ce622f..2c9476a9bd3 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -88,7 +88,7 @@ def logdet(matrix, name=None): chol = gen_linalg_ops.cholesky(matrix) return 2.0 * math_ops.reduce_sum( math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))), - reduction_indices=[-1]) + axis=[-1]) @tf_export('linalg.adjoint') diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 6fb7a57e4d9..8efafda3a1e 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -690,7 +690,7 @@ class LinearOperator(object): " Requires conversion to a dense matrix and O(N^3) operations.") if self._can_use_cholesky(): diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())) - return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1]) + return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1]) _, log_abs_det = linalg.slogdet(self.to_dense()) return log_abs_det diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py index 09f0c518e7a..b74baa5dfdb 100644 --- a/tensorflow/python/ops/linalg/linear_operator_circulant.py +++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py @@ -418,15 +418,13 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator): return math_ops.cast(y, self.dtype) def _determinant(self): - reduction_indices = [-(i + 1) for i in range(self.block_depth)] - det = math_ops.reduce_prod( - self.spectrum, reduction_indices=reduction_indices) + axis = [-(i + 1) for i in range(self.block_depth)] + det = math_ops.reduce_prod(self.spectrum, axis=axis) return math_ops.cast(det, self.dtype) def _log_abs_determinant(self): - reduction_indices = [-(i + 1) for i in range(self.block_depth)] - lad = math_ops.reduce_sum( - math_ops.log(self._abs_spectrum), reduction_indices=reduction_indices) + axis = [-(i + 1) for i in range(self.block_depth)] + lad = math_ops.reduce_sum(math_ops.log(self._abs_spectrum), axis=axis) return math_ops.cast(lad, self.dtype) def _solve(self, rhs, adjoint=False, adjoint_arg=False): diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index ed53decc00d..be893c705c9 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -228,11 +228,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator): return diag_mat * x def _determinant(self): - return math_ops.reduce_prod(self._diag, reduction_indices=[-1]) + return math_ops.reduce_prod(self._diag, axis=[-1]) def _log_abs_determinant(self): log_det = math_ops.reduce_sum( - math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) + math_ops.log(math_ops.abs(self._diag)), axis=[-1]) if self.dtype.is_complex: log_det = math_ops.cast(log_det, dtype=self.dtype) return log_det diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index c4288ff8f87..aa0500aff06 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -391,7 +391,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): if self._use_cholesky: chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance) log_abs_det_c = 2 * math_ops.reduce_sum( - math_ops.log(chol_cap_diag), reduction_indices=[-1]) + math_ops.log(chol_cap_diag), axis=[-1]) else: det_c = linalg_ops.matrix_determinant(self._capacitance) log_abs_det_c = math_ops.log(math_ops.abs(det_c)) diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index ca6d3f54051..d33fe17e042 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -195,11 +195,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg) def _determinant(self): - return math_ops.reduce_prod(self._diag, reduction_indices=[-1]) + return math_ops.reduce_prod(self._diag, axis=[-1]) def _log_abs_determinant(self): return math_ops.reduce_sum( - math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) + math_ops.log(math_ops.abs(self._diag)), axis=[-1]) def _solve(self, rhs, adjoint=False, adjoint_arg=False): rhs = linalg.adjoint(rhs) if adjoint_arg else rhs diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index e8cadf931bc..0a5b511f820 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -583,12 +583,10 @@ def mean_pairwise_squared_error( diffs = math_ops.subtract(predictions, labels) - reduction_indices = math_ops.range(1, array_ops.rank(diffs)) + axis = math_ops.range(1, array_ops.rank(diffs)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), - reduction_indices=reduction_indices, - keepdims=True) + math_ops.square(diffs), axis=axis, keepdims=True) num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * math_ops.div_no_nan( @@ -596,8 +594,7 @@ def mean_pairwise_squared_error( math_ops.maximum(num_present_per_batch - 1, 0), name="value") - sum_diff = math_ops.reduce_sum( - diffs, reduction_indices=reduction_indices, keepdims=True) + sum_diff = math_ops.reduce_sum(diffs, axis=axis, keepdims=True) term2 = 2.0 * math_ops.div_no_nan( math_ops.square(sum_diff), math_ops.maximum( diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 952a2a1e798..73ca3d527ab 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1314,7 +1314,7 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disa # Reduction operations -def _ReductionDims(x, axis, reduction_indices): +def _ReductionDims(x, axis, reduction_indices=None): # pylint: disable=invalid-name """Returns range(0, rank(x)) if reduction_indices is None.""" # TODO(aselle): Remove this after deprecation if reduction_indices is not None: @@ -1337,23 +1337,23 @@ def _ReductionDims(x, axis, reduction_indices): return range(0, array_ops.rank(x)) -def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output): +def _may_reduce_to_scalar(keepdims, axis, output): """Set a reduction's output shape to be a scalar if we are certain.""" if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and ( - axis is None) and (reduction_indices is None): + axis is None): output.set_shape(()) return output -@tf_export("math.reduce_sum", "reduce_sum") +@tf_export(v1=["math.reduce_sum", "reduce_sum"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_sum(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_sum_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the sum of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1393,18 +1393,58 @@ def reduce_sum(input_tensor, int64 while tensorflow returns the same dtype as the input. @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False + return reduce_sum(input_tensor, axis, keepdims, name) - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._sum( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + +@tf_export("math.reduce_sum", "reduce_sum", v1=[]) +def reduce_sum(input_tensor, axis=None, keepdims=False, name=None): + """Computes the sum of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + For example: + + ```python + x = tf.constant([[1, 1, 1], [1, 1, 1]]) + tf.reduce_sum(x) # 6 + tf.reduce_sum(x, 0) # [2, 2, 2] + tf.reduce_sum(x, 1) # [3, 3] + tf.reduce_sum(x, 1, keepdims=True) # [[3], [3]] + tf.reduce_sum(x, [0, 1]) # 6 + ``` + + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor, of the same dtype as the input_tensor. + + @compatibility(numpy) + Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to + int64 while tensorflow returns the same dtype as the input. + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops._sum( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) @tf_export(v1=["math.count_nonzero", "count_nonzero"]) @@ -1544,15 +1584,13 @@ def count_nonzero_v2(input, # pylint: disable=redefined-builtin dtype=dtype) -@tf_export("math.reduce_mean", "reduce_mean") -@deprecation.deprecated_args( - None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_mean(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +@tf_export(v1=["math.reduce_mean", "reduce_mean"]) +def reduce_mean_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the mean of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1602,22 +1640,72 @@ def reduce_mean(input_tensor, @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) + return reduce_mean(input_tensor, axis, keepdims, name) - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops.mean( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + +@tf_export("math.reduce_mean", "reduce_mean", v1=[]) +def reduce_mean(input_tensor, axis=None, keepdims=False, name=None): + """Computes the mean of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + For example: + + ```python + x = tf.constant([[1., 1.], [2., 2.]]) + tf.reduce_mean(x) # 1.5 + tf.reduce_mean(x, 0) # [1.5, 1.5] + tf.reduce_mean(x, 1) # [1., 2.] + ``` + + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.mean + + Please note that `np.mean` has a `dtype` parameter that could be used to + specify the output type. By default this is `dtype=float64`. On the other + hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`, + for example: + + ```python + x = tf.constant([1, 0, 1, 0]) + tf.reduce_mean(x) # 0 + y = tf.constant([1., 0., 1., 0.]) + tf.reduce_mean(y) # 0.5 + ``` + + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops.mean( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) @tf_export("math.reduce_variance") -def reduce_variance(input_tensor, axis=None, keepdims=None, name=None): +def reduce_variance(input_tensor, axis=None, keepdims=False, name=None): """Computes the variance of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1665,7 +1753,7 @@ def reduce_variance(input_tensor, axis=None, keepdims=None, name=None): @tf_export("math.reduce_std") -def reduce_std(input_tensor, axis=None, keepdims=None, name=None): +def reduce_std(input_tensor, axis=None, keepdims=False, name=None): """Computes the standard deviation of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1710,15 +1798,8 @@ def reduce_std(input_tensor, axis=None, keepdims=None, name=None): return sqrt(variance) -@tf_export("math.reduce_prod", "reduce_prod") -@deprecation.deprecated_args( - None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_prod(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +@tf_export("math.reduce_prod", "reduce_prod", v1=[]) +def reduce_prod(input_tensor, axis=None, keepdims=False, name=None): """Computes the product of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1736,6 +1817,48 @@ def reduce_prod(input_tensor, `[-rank(input_tensor), rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.prod + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops.prod( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) + + +@tf_export(v1=["math.reduce_prod", "reduce_prod"]) +@deprecation.deprecated_args( + None, "keep_dims is deprecated, use keepdims instead", "keep_dims") +def reduce_prod_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): + """Computes the product of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. keep_dims: Deprecated alias for `keepdims`. @@ -1746,29 +1869,22 @@ def reduce_prod(input_tensor, Equivalent to np.prod @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops.prod( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + return reduce_prod(input_tensor, axis, keepdims, name) -@tf_export("math.reduce_min", "reduce_min") +@tf_export(v1=["math.reduce_min", "reduce_min"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_min(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_min_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the minimum of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1781,9 +1897,9 @@ def reduce_min(input_tensor, Args: input_tensor: The tensor to reduce. Should have real numeric type. - axis: The dimensions to reduce. If `None` (the default), - reduces all dimensions. Must be in the range - `[-rank(input_tensor), rank(input_tensor))`. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. @@ -1796,28 +1912,57 @@ def reduce_min(input_tensor, Equivalent to np.min @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._min( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + return reduce_min(input_tensor, axis, keepdims, name) -@tf_export("math.reduce_max", "reduce_max") +@tf_export("math.reduce_min", "reduce_min", v1=[]) +def reduce_min(input_tensor, axis=None, keepdims=False, name=None): + """Computes the minimum of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + Args: + input_tensor: The tensor to reduce. Should have real numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.min + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops._min( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) + + +@tf_export(v1=["math.reduce_max", "reduce_max"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_max(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_max_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the maximum of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1845,28 +1990,57 @@ def reduce_max(input_tensor, Equivalent to np.max @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._max( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + return reduce_max(input_tensor, axis, keepdims, name) -@tf_export("math.reduce_all", "reduce_all") +@tf_export("math.reduce_max", "reduce_max", v1=[]) +def reduce_max(input_tensor, axis=None, keepdims=False, name=None): + """Computes the maximum of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + Args: + input_tensor: The tensor to reduce. Should have real numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.max + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops._max( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) + + +@tf_export(v1=["math.reduce_all", "reduce_all"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_all(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_all_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the "logical and" of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1888,9 +2062,9 @@ def reduce_all(input_tensor, Args: input_tensor: The boolean tensor to reduce. - axis: The dimensions to reduce. If `None` (the default), - reduces all dimensions. Must be in the range - `[-rank(input_tensor), rank(input_tensor))`. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. @@ -1903,28 +2077,66 @@ def reduce_all(input_tensor, Equivalent to np.all @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._all( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + return reduce_all(input_tensor, axis, keepdims, name) -@tf_export("math.reduce_any", "reduce_any") +@tf_export("reduce_all", "math.reduce_all", v1=[]) +def reduce_all(input_tensor, axis=None, keepdims=False, name=None): + """Computes the "logical and" of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + For example: + + ```python + x = tf.constant([[True, True], [False, False]]) + tf.reduce_all(x) # False + tf.reduce_all(x, 0) # [False, False] + tf.reduce_all(x, 1) # [True, False] + ``` + + Args: + input_tensor: The boolean tensor to reduce. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.all + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops._all( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) + + +@tf_export(v1=["math.reduce_any", "reduce_any"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_any(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_any_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes the "logical or" of elements across dimensions of a tensor. Reduces `input_tensor` along the dimensions given in `axis`. @@ -1946,9 +2158,9 @@ def reduce_any(input_tensor, Args: input_tensor: The boolean tensor to reduce. - axis: The dimensions to reduce. If `None` (the default), - reduces all dimensions. Must be in the range - `[-rank(input_tensor), rank(input_tensor))`. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. @@ -1961,28 +2173,66 @@ def reduce_any(input_tensor, Equivalent to np.any @end_compatibility """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._any( - input_tensor, - _ReductionDims(input_tensor, axis, - reduction_indices), - keepdims, - name=name)) + return reduce_any(input_tensor, axis, keepdims, name) -@tf_export("math.reduce_logsumexp", "reduce_logsumexp") +@tf_export("math.reduce_any", "reduce_any", v1=[]) +def reduce_any(input_tensor, axis=None, keepdims=False, name=None): + """Computes the "logical or" of elements across dimensions of a tensor. + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. + + For example: + + ```python + x = tf.constant([[True, True], [False, False]]) + tf.reduce_any(x) # True + tf.reduce_any(x, 0) # [True, True] + tf.reduce_any(x, 1) # [True, False] + ``` + + Args: + input_tensor: The boolean tensor to reduce. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + + @compatibility(numpy) + Equivalent to np.any + @end_compatibility + """ + keepdims = False if keepdims is None else keepdims + return _may_reduce_to_scalar( + keepdims, axis, + gen_math_ops._any( + input_tensor, _ReductionDims(input_tensor, axis), keepdims, + name=name)) + + +@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"]) @deprecation.deprecated_args( None, "keep_dims is deprecated, use keepdims instead", "keep_dims") -def reduce_logsumexp(input_tensor, - axis=None, - keepdims=None, - name=None, - reduction_indices=None, - keep_dims=None): +def reduce_logsumexp_v1(input_tensor, + axis=None, + keepdims=None, + name=None, + reduction_indices=None, + keep_dims=None): """Computes log(sum(exp(elements across dimensions of a tensor))). Reduces `input_tensor` along the dimensions given in `axis`. @@ -2010,9 +2260,9 @@ def reduce_logsumexp(input_tensor, Args: input_tensor: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), - reduces all dimensions. Must be in the range - `[-rank(input_tensor), rank(input_tensor))`. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. keepdims: If true, retains reduced dimensions with length 1. name: A name for the operation (optional). reduction_indices: The old (deprecated) name for axis. @@ -2021,16 +2271,57 @@ def reduce_logsumexp(input_tensor, Returns: The reduced tensor. """ + axis = deprecation.deprecated_argument_lookup( + "axis", axis, "reduction_indices", reduction_indices) keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, "keep_dims", keep_dims) - if keepdims is None: - keepdims = False + return reduce_logsumexp(input_tensor, axis, keepdims, name) + + +@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[]) +def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): + """Computes log(sum(exp(elements across dimensions of a tensor))). + + Reduces `input_tensor` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + + If `axis` has no entries, all dimensions are reduced, and a + tensor with a single element is returned. + + This function is more numerically stable than log(sum(exp(input))). It avoids + overflows caused by taking the exp of large inputs and underflows caused by + taking the log of small inputs. + + For example: + + ```python + x = tf.constant([[0., 0., 0.], [0., 0., 0.]]) + tf.reduce_logsumexp(x) # log(6) + tf.reduce_logsumexp(x, 0) # [log(2), log(2), log(2)] + tf.reduce_logsumexp(x, 1) # [log(3), log(3)] + tf.reduce_logsumexp(x, 1, keepdims=True) # [[log(3)], [log(3)]] + tf.reduce_logsumexp(x, [0, 1]) # log(6) + ``` + + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + name: A name for the operation (optional). + + Returns: + The reduced tensor. + """ + keepdims = False if keepdims is None else keepdims input_tensor = ops.convert_to_tensor(input_tensor) with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name: raw_max = reduce_max( input_tensor, axis=axis, - reduction_indices=reduction_indices, keepdims=True) my_max = array_ops.stop_gradient( array_ops.where( @@ -2040,12 +2331,11 @@ def reduce_logsumexp(input_tensor, reduce_sum( gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)), axis, - keepdims=keepdims, - reduction_indices=reduction_indices)) + keepdims=keepdims)) if not keepdims: my_max = array_ops.reshape(my_max, array_ops.shape(result)) result = gen_math_ops.add(result, my_max) - return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result) + return _may_reduce_to_scalar(keepdims, axis, result) @tf_export("linalg.trace", v1=["linalg.trace", "trace"]) diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index e0329f66ff3..cd45b6f1364 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -104,7 +104,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): for dtype in [np.float16, np.float32, np.double]: x_np = np.random.rand(5, 5).astype(dtype) with self.cached_session(use_gpu=True): - y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=[0]) + y_tf = math_ops.reduce_logsumexp(x_np, axis=[0]) y_np = log(np.sum(exp(x_np), axis=0)) self.assertShapeEqual(y_np, y_tf) y_tf_np = self.evaluate(y_tf) @@ -114,7 +114,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): for dtype in [np.float16, np.float32, np.double]: x_np = np.random.rand(5, 5).astype(dtype) with self.cached_session(use_gpu=True): - y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=0) + y_tf = math_ops.reduce_logsumexp(x_np, axis=0) y_np = log(np.sum(exp(x_np), axis=0)) self.assertShapeEqual(y_np, y_tf) y_tf_np = self.evaluate(y_tf) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 0b91b8dde8e..27269c51c1b 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -948,7 +948,7 @@ def mean_cosine_distance(labels, predictions=predictions, labels=labels, weights=weights) radial_diffs = math_ops.multiply(predictions, labels) radial_diffs = math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ dim, ], keepdims=True) mean_distance, update_op = mean(radial_diffs, weights, None, None, name or @@ -3045,7 +3045,7 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx): # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. precision_sum = math_ops.reduce_sum( - relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum') + relevant_precision_per_k, axis=(-1,), name='precision_sum') # Divide by number of relevant items to get average precision. These are # the "num_relevant_items" and "AveP" terms from the formula above. diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 4470c0b9580..72db0952b43 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -1324,13 +1324,12 @@ class ControlFlowTest(PForTest): pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4) # Note that tf.while_loop does not work in the setup above. So we manually # construct the equivalent computation of the above loops here. - real_out = math_ops.reduce_sum(inp, reduction_indices=[0]) - real_out = math_ops.reduce_prod(real_out, reduction_indices=[1]) + real_out = math_ops.reduce_sum(inp, axis=[0]) + real_out = math_ops.reduce_prod(real_out, axis=[1]) # Note that gradients of real_out will accumulate the gradients across the # output value. Hence we do the same aggregation on pfor_out_grad. real_out_grad = gradient_ops.gradients(real_out, inp)[0] - sum_pfor_out_grad = math_ops.reduce_sum( - pfor_out_grad, reduction_indices=[0]) + sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0]) with session.Session() as sess: v1, v2, v1_grad, v2_grad = sess.run( diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 1223b290ff6..2ca9c0c647d 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -195,7 +195,7 @@ def _SparseTensorDenseMatMulGrad(op, grad): parts_a = array_ops.gather(grad, rows if not adj_a else cols) parts_b = array_ops.gather(b if not adj_b else array_ops.transpose(b), cols if not adj_a else rows) - a_values_grad = math_ops.reduce_sum(parts_a * parts_b, reduction_indices=1) + a_values_grad = math_ops.reduce_sum(parts_a * parts_b, axis=1) # gradients w.r.t. (a_indices, a_values, a_shape, b) return (None, a_values_grad, None, b_grad) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index f44f694109e..21f4996798e 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -70,8 +70,7 @@ def lbeta(x, name=None): x = ops.convert_to_tensor(x, name='x') # Note reduce_sum([]) = 0. - log_prod_gamma_x = math_ops.reduce_sum( - math_ops.lgamma(x), reduction_indices=[-1]) + log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1]) # Note lgamma(0) = infinity, so if x = [] # log_gamma_sum_x = lgamma(0) = infinity, and @@ -264,11 +263,11 @@ def einsum(equation, *inputs, **kwargs): missing_indices = set(temp_axis_labels) - set(output_axis_labels) if missing_indices: - reduction_indices = [ + axis = [ i for i, a in enumerate(temp_axis_labels) if a not in output_axis_labels ] - temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices) + temp = math_ops.reduce_sum(temp, axis=axis) temp_axis_labels = ''.join( a for a in temp_axis_labels if a in output_axis_labels) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt index 67f348be218..b7a99caeb7b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -318,7 +318,7 @@ tf_module { } member_method { name: "reduce_std" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_sum" @@ -326,7 +326,7 @@ tf_module { } member_method { name: "reduce_variance" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "rint" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index 86df9705144..5215cfbab0e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -290,43 +290,43 @@ tf_module { } member_method { name: "reduce_all" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_any" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_logsumexp" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_max" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_mean" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_min" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_prod" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_std" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_sum" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_variance" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "rint" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 0649ae267bb..873c41a390c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -898,35 +898,35 @@ tf_module { } member_method { name: "reduce_all" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_any" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_logsumexp" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_max" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_mean" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_min" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_prod" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "reduce_sum" - argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " } member_method { name: "register_tensor_conversion_function" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index f85b2f08dc7..aab7aa8af5f 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -73,6 +73,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "t": "x", "msg": "message", }, + "tf.sparse.add": ["a", "b", "thresh"], "tf.sparse.split": { "split_dim": "axis", }, @@ -113,6 +114,73 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.random.stateless_multinomial": { "output_dtype": "dtype", }, + "tf.sparse.concat": [ + "axis", "sp_inputs", "name", "expand_nonconcat_dim", "concat_dim" + ], + "tf.reduce_all": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_all": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_any": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_any": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_min": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_min": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_max": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_max": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_sum": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_sum": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_mean": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_mean": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_prod": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_prod": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.reduce_logsumexp": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, + "tf.math.reduce_logsumexp": { + "reduction_indices": "axis", + "keep_dims": "keepdims", + }, } # Mapping from function to the new name of the function @@ -199,7 +267,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.convert_to_tensor": ["value", "dtype", "name", "preferred_dtype"], "tf.nn.convolution": [ "input", "filter", "padding", "strides", "dilation_rate", "name", - "data_format"], + "data_format" + ], "tf.nn.crelu": ["features", "name", "axis"], "tf.nn.pool": [ "input", "window_shape", "pooling_type", "padding", "dilation_rate", @@ -218,19 +287,19 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): ], "tf.pad": ["tensor", "paddings", "mode", "name", "constant_values"], "tf.quantize_v2": [ - "input", "min_range", "max_range", "T", "mode", "name", - "round_mode" + "input", "min_range", "max_range", "T", "mode", "name", "round_mode" ], "tf.feature_column.categorical_column_with_vocabulary_file": [ - "key", "vocabulary_file", "vocabulary_size", - "num_oov_buckets", "default_value", "dtype" + "key", "vocabulary_file", "vocabulary_size", "num_oov_buckets", + "default_value", "dtype" ], "tf.shape": ["input", "name", "out_type"], "tf.size": ["input", "name", "out_type"], + "tf.random.poisson": ["lam", "shape", "dtype", "seed", "name"], + "tf.sparse.add": ["a", "b", "thresh"], "tf.sparse.concat": [ "axis", "sp_inputs", "name", "expand_nonconcat_dim", "concat_dim" ], - "tf.random.poisson": ["lam", "shape", "dtype", "seed", "name"], "tf.sparse.segment_mean": [ "data", "indices", "segment_ids", "name", "num_segments" ], @@ -243,10 +312,75 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.strings.length": ["input", "name", "unit"], "tf.transpose": ["a", "perm", "name", "conjugate"], "tf.tuple": ["tensors", "name", "control_inputs"], - "tf.while_loop": ["cond", "body", "loop_vars", "shape_invariants", - "parallel_iterations", "back_prop", "swap_memory", - "name", "maximum_iterations", - "return_same_structure"], + "tf.while_loop": [ + "cond", "body", "loop_vars", "shape_invariants", + "parallel_iterations", "back_prop", "swap_memory", "name", + "maximum_iterations", "return_same_structure" + ], + "tf.reduce_all": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_all": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_any": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_any": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_min": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_min": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_max": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_max": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_sum": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_sum": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_mean": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_mean": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_prod": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_prod": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.reduce_logsumexp": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], + "tf.math.reduce_logsumexp": [ + "input_tensor", "axis", "keepdims", "name", "reduction_indices", + "keep_dims" + ], } # Specially handled functions.