Move weights to be the second argument in tf.bincount.

This matches np.bincount(x, weights, minlength).
Change: 153476565
This commit is contained in:
Dan Ringwalt 2017-04-18 09:16:55 -08:00 committed by TensorFlower Gardener
parent d8b65d4035
commit 4b57da4ced
3 changed files with 7 additions and 7 deletions

View File

@ -73,13 +73,13 @@ class BincountTest(test_util.TensorFlowTestCase):
else: else:
weights = np.random.random(num_samples) weights = np.random.random(num_samples)
self.assertAllEqual( self.assertAllEqual(
math_ops.bincount(arr, weights=weights).eval(), math_ops.bincount(arr, weights).eval(),
np.bincount(arr, weights)) np.bincount(arr, weights))
def test_zero_weights(self): def test_zero_weights(self):
with self.test_session(): with self.test_session():
self.assertAllEqual( self.assertAllEqual(
math_ops.bincount(np.arange(1000), weights=np.zeros(1000)).eval(), math_ops.bincount(np.arange(1000), np.zeros(1000)).eval(),
np.zeros(1000)) np.zeros(1000))
def test_negative(self): def test_negative(self):

View File

@ -2026,9 +2026,9 @@ def tanh(x, name=None):
def bincount(arr, def bincount(arr,
weights=None,
minlength=None, minlength=None,
maxlength=None, maxlength=None,
weights=None,
dtype=dtypes.int32): dtype=dtypes.int32):
"""Counts the number of occurrences of each value in an integer array. """Counts the number of occurrences of each value in an integer array.
@ -2040,13 +2040,13 @@ def bincount(arr,
Args: Args:
arr: An int32 tensor of non-negative values. arr: An int32 tensor of non-negative values.
weights: If non-None, must be the same shape as arr. For each value in
`arr`, the bin will be incremented by the corresponding weight instead
of 1.
minlength: If given, ensures the output has length at least `minlength`, minlength: If given, ensures the output has length at least `minlength`,
padding with zeros at the end if necessary. padding with zeros at the end if necessary.
maxlength: If given, skips values in `arr` that are equal or greater than maxlength: If given, skips values in `arr` that are equal or greater than
`maxlength`, ensuring that the output has length at most `maxlength`. `maxlength`, ensuring that the output has length at most `maxlength`.
weights: If non-None, must be the same shape as arr. For each value in
`arr`, the bin will be incremented by the corresponding weight instead
of 1.
dtype: If `weights` is None, determines the type of the output bins. dtype: If `weights` is None, determines the type of the output bins.
Returns: Returns:

View File

@ -650,7 +650,7 @@ tf_module {
} }
member_method { member_method {
name: "bincount" name: "bincount"
argspec: "args=[\'arr\', \'minlength\', \'maxlength\', \'weights\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\"], " argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\"], "
} }
member_method { member_method {
name: "bitcast" name: "bitcast"