Allow negative axes in tf.nn.sufficient_statistics even when shape of x is unknown.
PiperOrigin-RevId: 323065341 Change-Id: I38b4750077030b2243ef18c34a4ee76eeb2883c8
This commit is contained in:
parent
87ab16969f
commit
926c086248
@ -1158,9 +1158,23 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
|
||||
an input that's optionally shifted. See:
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
|
||||
|
||||
For example:
|
||||
>>> t = [[1, 2, 3], [4, 5, 6]]
|
||||
>>> sufficient_statistics(t, [1])
|
||||
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
|
||||
dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
|
||||
dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
|
||||
>>> sufficient_statistics(t, [-1])
|
||||
(<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
|
||||
dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
|
||||
dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
|
||||
|
||||
Args:
|
||||
x: A `Tensor`.
|
||||
axes: Array of ints. Axes along which to compute mean and variance.
|
||||
axes: Array of ints. Axes along which to compute mean and variance. As in
|
||||
Python, the axes can also be negative numbers. A negative axis is
|
||||
interpreted as counting from the end of the rank, i.e., axis +
|
||||
rank(values)-th dimension.
|
||||
shift: A `Tensor` containing the value by which to shift the data for
|
||||
numerical stability, or `None` if no shift is to be performed. A shift
|
||||
close to the true mean provides the most numerically stable results.
|
||||
@ -1191,8 +1205,11 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
|
||||
counts *= x_shape.dims[d].value
|
||||
counts = constant_op.constant(counts, dtype=x.dtype)
|
||||
else: # shape needs to be inferred at runtime.
|
||||
# Normalize axes to be positive. Required for gather.
|
||||
rank = array_ops.rank(x)
|
||||
positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
|
||||
x_dims = array_ops.gather(
|
||||
math_ops.cast(array_ops.shape(x), x.dtype), axes)
|
||||
math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
|
||||
counts = math_ops.reduce_prod(x_dims, name="count")
|
||||
if shift is not None:
|
||||
shift = ops.convert_to_tensor(shift, name="shift")
|
||||
|
Loading…
Reference in New Issue
Block a user