Allow negative axes in tf.nn.sufficient_statistics even when shape of x is unknown.

PiperOrigin-RevId: 323065341
Change-Id: I38b4750077030b2243ef18c34a4ee76eeb2883c8
This commit is contained in:
A. Unique TensorFlower 2020-07-24 13:51:47 -07:00 committed by TensorFlower Gardener
parent 87ab16969f
commit 926c086248

View File

@ -1158,9 +1158,23 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
an input that's optionally shifted. See:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
For example:
>>> t = [[1, 2, 3], [4, 5, 6]]
>>> sufficient_statistics(t, [1])
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
>>> sufficient_statistics(t, [-1])
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
Args:
x: A `Tensor`.
axes: Array of ints. Axes along which to compute mean and variance.
axes: Array of ints. Axes along which to compute mean and variance. As in
Python, the axes can also be negative numbers. A negative axis is
interpreted as counting from the end of the rank, i.e., axis +
rank(values)-th dimension.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
close to the true mean provides the most numerically stable results.
@ -1191,8 +1205,11 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
counts *= x_shape.dims[d].value
counts = constant_op.constant(counts, dtype=x.dtype)
else: # shape needs to be inferred at runtime.
# Normalize axes to be positive. Required for gather.
rank = array_ops.rank(x)
positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
x_dims = array_ops.gather(
math_ops.cast(array_ops.shape(x), x.dtype), axes)
math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
counts = math_ops.reduce_prod(x_dims, name="count")
if shift is not None:
shift = ops.convert_to_tensor(shift, name="shift")