use bincount_v2 in tf.math.bincount, support axis and binary_output, and support int64.
PiperOrigin-RevId: 312570490 Change-Id: I1d11cd0f294f6899920a547fbe0f8f9c54140be6
This commit is contained in:
parent
6e4fdec80e
commit
4be466a87e
@ -137,7 +137,7 @@ py_library(
|
||||
":_pywrap_utils",
|
||||
":array_ops",
|
||||
":audio_ops_gen",
|
||||
":bincount",
|
||||
":bincount_ops",
|
||||
":bitwise_ops",
|
||||
":boosted_trees_ops",
|
||||
":check_ops",
|
||||
@ -3476,23 +3476,24 @@ py_library(
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "bincount",
|
||||
srcs = ["ops/bincount.py"],
|
||||
name = "bincount_ops",
|
||||
srcs = ["ops/bincount_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":count_ops_gen",
|
||||
":framework",
|
||||
":framework_for_generated_wrappers",
|
||||
"//tensorflow/python/compat",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "bincount_test",
|
||||
name = "bincount_ops_test",
|
||||
size = "small",
|
||||
srcs = ["ops/bincount_test.py"],
|
||||
srcs = ["ops/bincount_ops_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":bincount",
|
||||
":bincount_ops",
|
||||
":platform_test",
|
||||
],
|
||||
)
|
||||
|
@ -85,7 +85,7 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.feature_column import feature_column_lib as feature_column
|
||||
from tensorflow.python.layers import layers
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import bincount
|
||||
from tensorflow.python.ops import bincount_ops
|
||||
from tensorflow.python.ops import bitwise_ops as bitwise
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import image_ops as image
|
||||
|
@ -178,9 +178,9 @@ cuda_py_test(
|
||||
srcs = ["bincount_op_test.py"],
|
||||
tags = ["no_windows_gpu"],
|
||||
deps = [
|
||||
"//tensorflow/python:bincount_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for math_ops.bincount."""
|
||||
"""Tests for bincount_ops.bincount."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -25,8 +25,8 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import bincount_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -37,45 +37,50 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_empty(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=5)),
|
||||
[0, 0, 0, 0, 0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=1)),
|
||||
[0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=0)),
|
||||
[])
|
||||
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=0,
|
||||
dtype=np.float32)).dtype,
|
||||
np.float32)
|
||||
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=3,
|
||||
dtype=np.float64)).dtype,
|
||||
np.float64)
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([], minlength=5)),
|
||||
[0, 0, 0, 0, 0])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([], minlength=1)), [0])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([], minlength=0)), [])
|
||||
self.assertEqual(
|
||||
self.evaluate(
|
||||
bincount_ops.bincount([], minlength=0, dtype=np.float32)).dtype,
|
||||
np.float32)
|
||||
self.assertEqual(
|
||||
self.evaluate(
|
||||
bincount_ops.bincount([], minlength=3, dtype=np.float64)).dtype,
|
||||
np.float64)
|
||||
|
||||
def test_values(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([1, 1, 1, 2, 2, 3])),
|
||||
[0, 3, 2, 1])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([1, 1, 1, 2, 2, 3])),
|
||||
[0, 3, 2, 1])
|
||||
arr = [1, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||
[0, 5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount(arr)), [0, 5, 4, 3, 2, 1])
|
||||
arr += [0, 0, 0, 0, 0, 0]
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||
[6, 5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount(arr)), [6, 5, 4, 3, 2, 1])
|
||||
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([])), [])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([0, 0, 0])), [3])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([5])),
|
||||
[0, 0, 0, 0, 0, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(np.arange(10000))),
|
||||
np.ones(10000))
|
||||
self.assertAllEqual(self.evaluate(bincount_ops.bincount([])), [])
|
||||
self.assertAllEqual(self.evaluate(bincount_ops.bincount([0, 0, 0])), [3])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([5])), [0, 0, 0, 0, 0, 1])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount(np.arange(10000))),
|
||||
np.ones(10000))
|
||||
|
||||
def test_maxlength(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([5], maxlength=3)),
|
||||
[0, 0, 0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([1], maxlength=3)),
|
||||
[0, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], maxlength=3)),
|
||||
[])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([5], maxlength=3)), [0, 0, 0])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([1], maxlength=3)), [0, 1])
|
||||
self.assertAllEqual(
|
||||
self.evaluate(bincount_ops.bincount([], maxlength=3)), [])
|
||||
|
||||
def test_random_with_weights(self):
|
||||
num_samples = 10000
|
||||
@ -88,7 +93,7 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
else:
|
||||
weights = np.random.random(num_samples)
|
||||
self.assertAllClose(
|
||||
self.evaluate(math_ops.bincount(arr, weights)),
|
||||
self.evaluate(bincount_ops.bincount(arr, weights)),
|
||||
np.bincount(arr, weights))
|
||||
|
||||
def test_random_without_weights(self):
|
||||
@ -99,20 +104,20 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
arr = np.random.randint(0, 1000, num_samples)
|
||||
weights = np.ones(num_samples).astype(dtype)
|
||||
self.assertAllClose(
|
||||
self.evaluate(math_ops.bincount(arr, None)),
|
||||
self.evaluate(bincount_ops.bincount(arr, None)),
|
||||
np.bincount(arr, weights))
|
||||
|
||||
def test_zero_weights(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(
|
||||
self.evaluate(math_ops.bincount(np.arange(1000), np.zeros(1000))),
|
||||
self.evaluate(bincount_ops.bincount(np.arange(1000), np.zeros(1000))),
|
||||
np.zeros(1000))
|
||||
|
||||
def test_negative(self):
|
||||
# unsorted_segment_sum will only report InvalidArgumentError on CPU
|
||||
with self.cached_session(), ops.device("/CPU:0"):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(math_ops.bincount([1, 2, 3, -1, 6, 8]))
|
||||
self.evaluate(bincount_ops.bincount([1, 2, 3, -1, 6, 8]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_shape_function(self):
|
||||
|
@ -12,21 +12,245 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# maxlengthations under the License.
|
||||
# ==============================================================================
|
||||
"""tf.sparse.bincount ops."""
|
||||
"""bincount ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import gen_count_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("math.bincount", v1=[])
|
||||
def bincount(arr,
|
||||
weights=None,
|
||||
minlength=None,
|
||||
maxlength=None,
|
||||
dtype=dtypes.int32,
|
||||
name=None,
|
||||
axis=None,
|
||||
binary_output=False):
|
||||
"""Counts the number of occurrences of each value in an integer array.
|
||||
|
||||
If `minlength` and `maxlength` are not given, returns a vector with length
|
||||
`tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
|
||||
If `weights` are non-None, then index `i` of the output stores the sum of the
|
||||
value in `weights` at each index where the corresponding value in `arr` is
|
||||
`i`.
|
||||
|
||||
```python
|
||||
values = tf.constant([1,1,2,3,2,4,4,5])
|
||||
tf.math.bincount(values) #[0 2 2 1 2 1]
|
||||
```
|
||||
Vector length = Maximum element in vector `values` is 5. Adding 1, which is 6
|
||||
will be the vector length.
|
||||
|
||||
Each bin value in the output indicates number of occurrences of the particular
|
||||
index. Here, index 1 in output has a value 2. This indicates value 1 occurs
|
||||
two times in `values`.
|
||||
|
||||
```python
|
||||
values = tf.constant([1,1,2,3,2,4,4,5])
|
||||
weights = tf.constant([1,5,0,1,0,5,4,5])
|
||||
tf.math.bincount(values, weights=weights) #[0 6 0 1 9 5]
|
||||
```
|
||||
Bin will be incremented by the corresponding weight instead of 1.
|
||||
Here, index 1 in output has a value 6. This is the summation of weights
|
||||
corresponding to the value in `values`.
|
||||
|
||||
**Bin-counting on a certain axis**
|
||||
|
||||
This example takes a 2 dimensional input and returns a `Tensor` with
|
||||
bincounting on each sample.
|
||||
|
||||
>>> data = np.array([[1, 2, 3, 0], [0, 0, 1, 2]], dtype=np.int32)
|
||||
>>> tf.math.bincount(data, axis=-1)
|
||||
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
|
||||
array([[1, 1, 1, 1],
|
||||
[2, 1, 1, 0]], dtype=int32)>
|
||||
|
||||
|
||||
**Bin-counting with binary_output**
|
||||
|
||||
This example gives binary output instead of counting the occurrence.
|
||||
|
||||
>>> data = np.array([[1, 2, 3, 0], [0, 0, 1, 2]], dtype=np.int32)
|
||||
>>> tf.math.bincount(data, axis=-1, binary_output=True)
|
||||
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
|
||||
array([[1, 1, 1, 1],
|
||||
[1, 1, 1, 0]], dtype=int32)>
|
||||
|
||||
Args:
|
||||
arr: A Tensor, RaggedTensor, or SparseTensor whose values should be counted.
|
||||
These tensors must have a rank of 2 if `axis=-1`.
|
||||
weights: If non-None, must be the same shape as arr. For each value in
|
||||
`arr`, the bin will be incremented by the corresponding weight instead of
|
||||
1.
|
||||
minlength: If given, ensures the output has length at least `minlength`,
|
||||
padding with zeros at the end if necessary.
|
||||
maxlength: If given, skips values in `arr` that are equal or greater than
|
||||
`maxlength`, ensuring that the output has length at most `maxlength`.
|
||||
dtype: If `weights` is None, determines the type of the output bins.
|
||||
name: A name scope for the associated operations (optional).
|
||||
axis: The axis to slice over. Axes at and below `axis` will be flattened
|
||||
before bin counting. Currently, only `0`, and `-1` are supported. If None,
|
||||
all axes will be flattened (identical to passing `0`).
|
||||
binary_output: If True, this op will output 1 instead of the number of times
|
||||
a token appears (equivalent to one_hot + reduce_any instead of one_hot +
|
||||
reduce_add). Defaults to False.
|
||||
|
||||
Returns:
|
||||
A vector with the same dtype as `weights` or the given `dtype`. The bin
|
||||
values.
|
||||
|
||||
Raises:
|
||||
`InvalidArgumentError` if negative values are provided as an input.
|
||||
|
||||
"""
|
||||
name = "bincount" if name is None else name
|
||||
with ops.name_scope(name):
|
||||
# Somehow forward compatible needs to be False.
|
||||
if not binary_output and axis is None:
|
||||
arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
|
||||
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0
|
||||
output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (
|
||||
math_ops.reduce_max(arr) + 1)
|
||||
if minlength is not None:
|
||||
minlength = ops.convert_to_tensor(
|
||||
minlength, name="minlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.maximum(minlength, output_size)
|
||||
if maxlength is not None:
|
||||
maxlength = ops.convert_to_tensor(
|
||||
maxlength, name="maxlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.minimum(maxlength, output_size)
|
||||
if weights is not None:
|
||||
weights = ops.convert_to_tensor(weights, name="weights")
|
||||
return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
|
||||
weights = constant_op.constant([], dtype)
|
||||
return gen_math_ops.bincount(arr, output_size, weights)
|
||||
|
||||
if not isinstance(arr, sparse_tensor.SparseTensor):
|
||||
arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr")
|
||||
if weights is not None:
|
||||
if not isinstance(weights, sparse_tensor.SparseTensor):
|
||||
weights = ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
weights, name="weights")
|
||||
|
||||
if weights is not None and binary_output:
|
||||
raise ValueError("binary_output and weights are mutually exclusive.")
|
||||
|
||||
if not arr.dtype.is_integer:
|
||||
arr = math_ops.cast(arr, dtypes.int32)
|
||||
if axis is None:
|
||||
axis = 0
|
||||
|
||||
if axis not in [0, -1]:
|
||||
raise ValueError("Unsupported axis value %s. Only 0 and -1 are currently "
|
||||
"supported." % axis)
|
||||
|
||||
if isinstance(arr, ragged_tensor.RaggedTensor):
|
||||
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr.values)) > 0
|
||||
else:
|
||||
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0
|
||||
if isinstance(arr, sparse_tensor.SparseTensor):
|
||||
output_size = math_ops.cast(array_is_nonempty, arr.dtype) * (
|
||||
math_ops.reduce_max(arr.values) + 1)
|
||||
else:
|
||||
output_size = math_ops.cast(array_is_nonempty, arr.dtype) * (
|
||||
math_ops.reduce_max(arr) + 1)
|
||||
if minlength is not None:
|
||||
minlength = ops.convert_to_tensor(
|
||||
minlength, name="minlength", dtype=arr.dtype)
|
||||
output_size = gen_math_ops.maximum(minlength, output_size)
|
||||
if maxlength is not None:
|
||||
maxlength = ops.convert_to_tensor(
|
||||
maxlength, name="maxlength", dtype=arr.dtype)
|
||||
output_size = gen_math_ops.minimum(maxlength, output_size)
|
||||
|
||||
if axis == 0:
|
||||
if isinstance(arr, sparse_tensor.SparseTensor):
|
||||
if weights is not None:
|
||||
weights = validate_sparse_weights(arr, weights, dtype)
|
||||
arr = arr.values
|
||||
elif isinstance(arr, ragged_tensor.RaggedTensor):
|
||||
if weights is not None:
|
||||
weights = validate_ragged_weights(arr, weights, dtype)
|
||||
arr = arr.values
|
||||
else:
|
||||
if weights is not None:
|
||||
weights = array_ops.reshape(weights, [-1])
|
||||
arr = array_ops.reshape(arr, [-1])
|
||||
|
||||
if isinstance(arr, sparse_tensor.SparseTensor):
|
||||
weights = validate_sparse_weights(arr, weights, dtype)
|
||||
return gen_math_ops.sparse_bincount(
|
||||
indices=arr.indices,
|
||||
values=arr.values,
|
||||
dense_shape=arr.dense_shape,
|
||||
size=output_size,
|
||||
weights=weights,
|
||||
binary_output=binary_output)
|
||||
elif isinstance(arr, ragged_tensor.RaggedTensor):
|
||||
weights = validate_ragged_weights(arr, weights, dtype)
|
||||
return gen_math_ops.ragged_bincount(
|
||||
splits=arr.row_splits,
|
||||
values=arr.values,
|
||||
size=output_size,
|
||||
weights=weights,
|
||||
binary_output=binary_output)
|
||||
else:
|
||||
weights = validate_dense_weights(arr, weights, dtype)
|
||||
return gen_math_ops.dense_bincount(
|
||||
input=arr,
|
||||
size=output_size,
|
||||
weights=weights,
|
||||
binary_output=binary_output)
|
||||
|
||||
|
||||
@tf_export(v1=["math.bincount", "bincount"])
|
||||
@deprecation.deprecated_endpoints("bincount")
|
||||
def bincount_v1(arr,
|
||||
weights=None,
|
||||
minlength=None,
|
||||
maxlength=None,
|
||||
dtype=dtypes.int32):
|
||||
"""Counts the number of occurrences of each value in an integer array.
|
||||
|
||||
If `minlength` and `maxlength` are not given, returns a vector with length
|
||||
`tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
|
||||
If `weights` are non-None, then index `i` of the output stores the sum of the
|
||||
value in `weights` at each index where the corresponding value in `arr` is
|
||||
`i`.
|
||||
|
||||
Args:
|
||||
arr: An int32 tensor of non-negative values.
|
||||
weights: If non-None, must be the same shape as arr. For each value in
|
||||
`arr`, the bin will be incremented by the corresponding weight instead of
|
||||
1.
|
||||
minlength: If given, ensures the output has length at least `minlength`,
|
||||
padding with zeros at the end if necessary.
|
||||
maxlength: If given, skips values in `arr` that are equal or greater than
|
||||
`maxlength`, ensuring that the output has length at most `maxlength`.
|
||||
dtype: If `weights` is None, determines the type of the output bins.
|
||||
|
||||
Returns:
|
||||
A vector with the same dtype as `weights` or the given `dtype`. The bin
|
||||
values.
|
||||
"""
|
||||
return bincount(arr, weights, minlength, maxlength, dtype)
|
||||
|
||||
|
||||
@tf_export("sparse.bincount")
|
||||
def sparse_bincount(values,
|
||||
weights=None,
|
||||
@ -45,19 +269,17 @@ def sparse_bincount(values,
|
||||
|
||||
Args:
|
||||
values: A Tensor, RaggedTensor, or SparseTensor whose values should be
|
||||
counted. These tensors must have a rank of 1 or 2.
|
||||
weights: A 1-dimensional Tensor of weights. If specified, the input array is
|
||||
weighted by the weight array, i.e. if a value `n` is found at position
|
||||
`i`, `out[n]` will be increased by `weight[i]` instead of 1.
|
||||
counted. These tensors must have a rank of 2 if `axis=-1`.
|
||||
weights: If non-None, must be the same shape as arr. For each value in
|
||||
`value`, the bin will be incremented by the corresponding weight instead
|
||||
of 1.
|
||||
axis: The axis to slice over. Axes at and below `axis` will be flattened
|
||||
before bin counting. Currently, only `0`, and `-1` are supported. If None,
|
||||
all axes will be flattened (identical to passing `0`).
|
||||
minlength: If given, skips `values` that are less than `minlength`, and
|
||||
ensures that the output has a `dense_shape` of at least `minlength` in the
|
||||
inner dimension.
|
||||
maxlength: If given, skips `values` that are greater than or equal to
|
||||
`maxlength`, and ensures that the output has a `dense_shape` of at most
|
||||
`maxlength` in the inner dimension.
|
||||
minlength: If given, ensures the output has length at least `minlength`,
|
||||
padding with zeros at the end if necessary.
|
||||
maxlength: If given, skips values in `values` that are equal or greater than
|
||||
`maxlength`, ensuring that the output has length at most `maxlength`.
|
||||
binary_output: If True, this op will output 1 instead of the number of times
|
||||
a token appears (equivalent to one_hot + reduce_any instead of one_hot +
|
||||
reduce_add). Defaults to False.
|
||||
@ -229,9 +451,11 @@ def sparse_bincount(values,
|
||||
return sparse_tensor.SparseTensor(c_ind, c_val, c_shape)
|
||||
|
||||
|
||||
def validate_dense_weights(values, weights):
|
||||
def validate_dense_weights(values, weights, dtype=None):
|
||||
"""Validates the passed weight tensor or creates an empty one."""
|
||||
if weights is None:
|
||||
if dtype:
|
||||
return array_ops.constant([], dtype=dtype)
|
||||
return array_ops.constant([], dtype=values.dtype)
|
||||
|
||||
if not isinstance(weights, ops.Tensor):
|
||||
@ -241,9 +465,11 @@ def validate_dense_weights(values, weights):
|
||||
return weights
|
||||
|
||||
|
||||
def validate_sparse_weights(values, weights):
|
||||
def validate_sparse_weights(values, weights, dtype=None):
|
||||
"""Validates the passed weight tensor or creates an empty one."""
|
||||
if weights is None:
|
||||
if dtype:
|
||||
return array_ops.constant([], dtype=dtype)
|
||||
return array_ops.constant([], dtype=values.values.dtype)
|
||||
|
||||
if not isinstance(weights, sparse_tensor.SparseTensor):
|
||||
@ -273,9 +499,11 @@ def validate_sparse_weights(values, weights):
|
||||
return weights
|
||||
|
||||
|
||||
def validate_ragged_weights(values, weights):
|
||||
def validate_ragged_weights(values, weights, dtype=None):
|
||||
"""Validates the passed weight tensor or creates an empty one."""
|
||||
if weights is None:
|
||||
if dtype:
|
||||
return array_ops.constant([], dtype=dtype)
|
||||
return array_ops.constant([], dtype=values.values.dtype)
|
||||
|
||||
if not isinstance(weights, ragged_tensor.RaggedTensor):
|
@ -23,9 +23,12 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import bincount
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import bincount_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -151,7 +154,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
|
||||
binary_output=False,
|
||||
weights=None,
|
||||
axis=-1):
|
||||
y = bincount.sparse_bincount(
|
||||
y = bincount_ops.sparse_bincount(
|
||||
x,
|
||||
weights=weights,
|
||||
minlength=minlength,
|
||||
@ -349,7 +352,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
|
||||
axis=-1):
|
||||
x_sparse = sparse_ops.from_dense(x)
|
||||
w_sparse = sparse_ops.from_dense(weights) if weights is not None else None
|
||||
y = bincount.sparse_bincount(
|
||||
y = bincount_ops.sparse_bincount(
|
||||
x_sparse,
|
||||
weights=w_sparse,
|
||||
minlength=minlength,
|
||||
@ -496,7 +499,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
|
||||
axis=-1):
|
||||
x_ragged = ragged_factory_ops.constant(x)
|
||||
w = ragged_factory_ops.constant(weights) if weights is not None else None
|
||||
y = bincount.sparse_bincount(
|
||||
y = bincount_ops.sparse_bincount(
|
||||
x_ragged,
|
||||
weights=w,
|
||||
minlength=minlength,
|
||||
@ -508,6 +511,237 @@ class TestSparseCount(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(expected_shape, y.dense_shape)
|
||||
|
||||
|
||||
class TestDenseBincount(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_sparse_input_all_count(self, dtype):
|
||||
np.random.seed(42)
|
||||
num_rows = 128
|
||||
size = 1000
|
||||
n_elems = 4096
|
||||
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
|
||||
inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1)
|
||||
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
|
||||
sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals,
|
||||
[num_rows, 1])
|
||||
|
||||
np_out = np.bincount(inp_vals, minlength=size)
|
||||
self.assertAllEqual(
|
||||
np_out, self.evaluate(bincount_ops.bincount(sparse_inp, axis=0)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_sparse_input_all_count_with_weights(self, dtype):
|
||||
np.random.seed(42)
|
||||
num_rows = 128
|
||||
size = 1000
|
||||
n_elems = 4096
|
||||
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
|
||||
inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1)
|
||||
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
|
||||
sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals,
|
||||
[num_rows, 1])
|
||||
weight_vals = np.random.random((n_elems,))
|
||||
sparse_weights = sparse_tensor.SparseTensor(inp_indices, weight_vals,
|
||||
[num_rows, 1])
|
||||
|
||||
np_out = np.bincount(inp_vals, minlength=size, weights=weight_vals)
|
||||
self.assertAllEqual(
|
||||
np_out,
|
||||
self.evaluate(bincount_ops.bincount(
|
||||
sparse_inp, sparse_weights, axis=0)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_sparse_input_all_binary(self, dtype):
|
||||
np.random.seed(42)
|
||||
num_rows = 128
|
||||
size = 10
|
||||
n_elems = 4096
|
||||
inp_indices = np.random.randint(0, num_rows, (n_elems, 1))
|
||||
inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1)
|
||||
inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype)
|
||||
sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals,
|
||||
[num_rows, 1])
|
||||
|
||||
np_out = np.ones((size,))
|
||||
self.assertAllEqual(
|
||||
np_out,
|
||||
self.evaluate(bincount_ops.bincount(sparse_inp, binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_sparse_input_col_reduce_count(self, dtype):
|
||||
num_rows = 128
|
||||
num_cols = 27
|
||||
size = 100
|
||||
np.random.seed(42)
|
||||
inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
|
||||
np_out = np.reshape(
|
||||
np.concatenate(
|
||||
[np.bincount(inp[j, :], minlength=size) for j in range(num_rows)],
|
||||
axis=0), (num_rows, size))
|
||||
# from_dense will filter out 0s.
|
||||
inp = inp + 1
|
||||
# from_dense will cause OOM in GPU.
|
||||
with ops.device("/CPU:0"):
|
||||
inp_sparse = sparse_ops.from_dense(inp)
|
||||
inp_sparse = sparse_tensor.SparseTensor(inp_sparse.indices,
|
||||
inp_sparse.values - 1,
|
||||
inp_sparse.dense_shape)
|
||||
self.assertAllEqual(
|
||||
np_out, self.evaluate(bincount_ops.bincount(arr=inp_sparse, axis=-1)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_sparse_input_col_reduce_binary(self, dtype):
|
||||
num_rows = 128
|
||||
num_cols = 27
|
||||
size = 100
|
||||
np.random.seed(42)
|
||||
inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
|
||||
np_out = np.reshape(
|
||||
np.concatenate([
|
||||
np.where(np.bincount(inp[j, :], minlength=size) > 0, 1, 0)
|
||||
for j in range(num_rows)
|
||||
],
|
||||
axis=0), (num_rows, size))
|
||||
# from_dense will filter out 0s.
|
||||
inp = inp + 1
|
||||
# from_dense will cause OOM in GPU.
|
||||
with ops.device("/CPU:0"):
|
||||
inp_sparse = sparse_ops.from_dense(inp)
|
||||
inp_sparse = sparse_tensor.SparseTensor(inp_sparse.indices,
|
||||
inp_sparse.values - 1,
|
||||
inp_sparse.dense_shape)
|
||||
self.assertAllEqual(
|
||||
np_out,
|
||||
self.evaluate(
|
||||
bincount_ops.bincount(arr=inp_sparse, axis=-1, binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_ragged_input_count(self, dtype):
|
||||
x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]],
|
||||
dtype)
|
||||
# pyformat: disable
|
||||
expected_output = [
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 2, 1]]
|
||||
# pyformat: enable
|
||||
self.assertAllEqual(expected_output,
|
||||
self.evaluate(bincount_ops.bincount(arr=x, axis=-1)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_ragged_input_binary(self, dtype):
|
||||
x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]])
|
||||
# pyformat: disable
|
||||
expected_output = [
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 1, 1]]
|
||||
# pyformat: enable
|
||||
self.assertAllEqual(
|
||||
expected_output,
|
||||
self.evaluate(
|
||||
bincount_ops.bincount(arr=x, axis=-1, binary_output=True)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_ragged_input_count_with_weights(self, dtype):
|
||||
x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]])
|
||||
weights = ragged_factory_ops.constant([[], [], [.1, .2, .3], [],
|
||||
[.2, .5, .6, .3]])
|
||||
# pyformat: disable
|
||||
expected_output = [
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[.2, .3, 0, .1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[.5, 0, 0, 0, .9, .2]]
|
||||
# pyformat: enable
|
||||
self.assertAllClose(
|
||||
expected_output,
|
||||
self.evaluate(bincount_ops.bincount(arr=x, weights=weights, axis=-1)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_ragged_input_count_np(self, dtype):
|
||||
np.random.seed(42)
|
||||
num_rows = 128
|
||||
num_cols = 27
|
||||
size = 1000
|
||||
inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
|
||||
np_out = np.reshape(
|
||||
np.concatenate(
|
||||
[np.bincount(inp[j, :], minlength=size) for j in range(num_rows)],
|
||||
axis=0), (num_rows, size))
|
||||
x = ragged_tensor.RaggedTensor.from_tensor(inp)
|
||||
self.assertAllEqual(
|
||||
np_out,
|
||||
self.evaluate(bincount_ops.bincount(arr=x, minlength=size, axis=-1)))
|
||||
|
||||
@parameterized.parameters([{
|
||||
"dtype": np.int32,
|
||||
}, {
|
||||
"dtype": np.int64,
|
||||
}])
|
||||
def test_ragged_input_count_np_with_weights(self, dtype):
|
||||
np.random.seed(42)
|
||||
num_rows = 128
|
||||
num_cols = 27
|
||||
size = 1000
|
||||
inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype)
|
||||
np_weight = np.random.random((num_rows, num_cols))
|
||||
np_out = np.reshape(
|
||||
np.concatenate([
|
||||
np.bincount(inp[j, :], weights=np_weight[j, :], minlength=size)
|
||||
for j in range(num_rows)
|
||||
],
|
||||
axis=0), (num_rows, size))
|
||||
x = ragged_tensor.RaggedTensor.from_tensor(inp)
|
||||
weights = ragged_tensor.RaggedTensor.from_tensor(np_weight)
|
||||
self.assertAllEqual(
|
||||
np_out,
|
||||
self.evaluate(
|
||||
bincount_ops.bincount(
|
||||
arr=x, weights=weights, minlength=size, axis=-1)))
|
||||
|
||||
|
||||
class TestSparseCountFailureModes(test.TestCase):
|
||||
|
||||
def test_dense_input_sparse_weights_fails(self):
|
||||
@ -515,13 +749,13 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
weights = sparse_ops.from_dense(
|
||||
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_dense_input_ragged_weights_fails(self):
|
||||
x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
|
||||
weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]])
|
||||
with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_dense_input_wrong_shape_fails(self):
|
||||
x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
|
||||
@ -532,24 +766,24 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
if context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must have the same shape"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
else:
|
||||
with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_sparse_input_dense_weights_fails(self):
|
||||
x = sparse_ops.from_dense(
|
||||
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
|
||||
with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_sparse_input_ragged_weights_fails(self):
|
||||
x = sparse_ops.from_dense(
|
||||
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]])
|
||||
with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_sparse_input_wrong_indices_fails(self):
|
||||
x = sparse_ops.from_dense(
|
||||
@ -558,7 +792,7 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
np.array([[3, 1, 0, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must have the same indices"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_sparse_input_too_many_indices_fails(self):
|
||||
x = sparse_ops.from_dense(
|
||||
@ -567,7 +801,7 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
np.array([[3, 1, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Incompatible shapes"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_sparse_input_wrong_shape_fails(self):
|
||||
x = sparse_ops.from_dense(
|
||||
@ -577,27 +811,27 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
dtype=np.int32))
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must have the same dense shape"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_ragged_input_dense_weights_fails(self):
|
||||
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
|
||||
weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32)
|
||||
with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_ragged_input_sparse_weights_fails(self):
|
||||
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
|
||||
weights = sparse_ops.from_dense(
|
||||
np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32))
|
||||
with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
def test_ragged_input_different_shape_fails(self):
|
||||
x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]])
|
||||
weights = ragged_factory_ops.constant([[6, 0.5, 2], [], [10, 0.25, 5, 3]])
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"must have the same row splits"):
|
||||
self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1))
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
@ -3562,116 +3562,6 @@ def log_sigmoid(x, name=None):
|
||||
return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name)
|
||||
|
||||
|
||||
@tf_export("math.bincount", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def bincount(arr,
|
||||
weights=None,
|
||||
minlength=None,
|
||||
maxlength=None,
|
||||
dtype=dtypes.int32,
|
||||
name=None):
|
||||
"""Counts the number of occurrences of each value in an integer array.
|
||||
|
||||
If `minlength` and `maxlength` are not given, returns a vector with length
|
||||
`tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
|
||||
If `weights` are non-None, then index `i` of the output stores the sum of the
|
||||
value in `weights` at each index where the corresponding value in `arr` is
|
||||
`i`.
|
||||
|
||||
```python
|
||||
values = tf.constant([1,1,2,3,2,4,4,5])
|
||||
tf.math.bincount(values) #[0 2 2 1 2 1]
|
||||
```
|
||||
Vector length = Maximum element in vector `values` is 5. Adding 1, which is 6
|
||||
will be the vector length.
|
||||
|
||||
Each bin value in the output indicates number of occurrences of the particular
|
||||
index. Here, index 1 in output has a value 2. This indicates value 1 occurs
|
||||
two times in `values`.
|
||||
|
||||
```python
|
||||
values = tf.constant([1,1,2,3,2,4,4,5])
|
||||
weights = tf.constant([1,5,0,1,0,5,4,5])
|
||||
tf.math.bincount(values, weights=weights) #[0 6 0 1 9 5]
|
||||
```
|
||||
Bin will be incremented by the corresponding weight instead of 1.
|
||||
Here, index 1 in output has a value 6. This is the summation of weights
|
||||
corresponding to the value in `values`.
|
||||
|
||||
Args:
|
||||
arr: An int32 tensor of non-negative values.
|
||||
weights: If non-None, must be the same shape as arr. For each value in
|
||||
`arr`, the bin will be incremented by the corresponding weight instead of
|
||||
1.
|
||||
minlength: If given, ensures the output has length at least `minlength`,
|
||||
padding with zeros at the end if necessary.
|
||||
maxlength: If given, skips values in `arr` that are equal or greater than
|
||||
`maxlength`, ensuring that the output has length at most `maxlength`.
|
||||
dtype: If `weights` is None, determines the type of the output bins.
|
||||
name: A name scope for the associated operations (optional).
|
||||
|
||||
Returns:
|
||||
A vector with the same dtype as `weights` or the given `dtype`. The bin
|
||||
values.
|
||||
|
||||
Raises:
|
||||
`InvalidArgumentError` if negative values are provided as an input.
|
||||
|
||||
"""
|
||||
name = "bincount" if name is None else name
|
||||
with ops.name_scope(name):
|
||||
arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
|
||||
array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
|
||||
output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
|
||||
if minlength is not None:
|
||||
minlength = ops.convert_to_tensor(
|
||||
minlength, name="minlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.maximum(minlength, output_size)
|
||||
if maxlength is not None:
|
||||
maxlength = ops.convert_to_tensor(
|
||||
maxlength, name="maxlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.minimum(maxlength, output_size)
|
||||
if weights is not None:
|
||||
weights = ops.convert_to_tensor(weights, name="weights")
|
||||
return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
|
||||
weights = constant_op.constant([], dtype)
|
||||
return gen_math_ops.bincount(arr, output_size, weights)
|
||||
|
||||
|
||||
@tf_export(v1=["math.bincount", "bincount"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("bincount")
|
||||
def bincount_v1(arr,
|
||||
weights=None,
|
||||
minlength=None,
|
||||
maxlength=None,
|
||||
dtype=dtypes.int32):
|
||||
"""Counts the number of occurrences of each value in an integer array.
|
||||
|
||||
If `minlength` and `maxlength` are not given, returns a vector with length
|
||||
`tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
|
||||
If `weights` are non-None, then index `i` of the output stores the sum of the
|
||||
value in `weights` at each index where the corresponding value in `arr` is
|
||||
`i`.
|
||||
|
||||
Args:
|
||||
arr: An int32 tensor of non-negative values.
|
||||
weights: If non-None, must be the same shape as arr. For each value in
|
||||
`arr`, the bin will be incremented by the corresponding weight instead of
|
||||
1.
|
||||
minlength: If given, ensures the output has length at least `minlength`,
|
||||
padding with zeros at the end if necessary.
|
||||
maxlength: If given, skips values in `arr` that are equal or greater than
|
||||
`maxlength`, ensuring that the output has length at most `maxlength`.
|
||||
dtype: If `weights` is None, determines the type of the output bins.
|
||||
|
||||
Returns:
|
||||
A vector with the same dtype as `weights` or the given `dtype`. The bin
|
||||
values.
|
||||
"""
|
||||
return bincount(arr, weights, minlength, maxlength, dtype)
|
||||
|
||||
|
||||
@tf_export("math.cumsum", "cumsum")
|
||||
@dispatch.add_dispatch_support
|
||||
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
|
||||
@ -4556,9 +4446,9 @@ def polyval(coeffs, x, name=None):
|
||||
|
||||
p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] +
|
||||
x * coeffs[0]))
|
||||
|
||||
|
||||
Usage Example:
|
||||
|
||||
|
||||
>>> coefficients = [1.0, 2.5, -4.2]
|
||||
>>> x = 5.0
|
||||
>>> y = tf.math.polyval(coefficients, x)
|
||||
|
@ -307,6 +307,7 @@ py_library(
|
||||
deps = [
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:bincount_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
@ -417,6 +418,7 @@ py_library(
|
||||
deps = [
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:bincount_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
|
@ -228,6 +228,9 @@ class RowPartition(composite_tensor.CompositeTensor):
|
||||
... nrows=4))
|
||||
tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64))
|
||||
"""
|
||||
# Local import bincount_ops to avoid import-cycle since bincount_ops
|
||||
# imports ragged_tensor.
|
||||
from tensorflow.python.ops import bincount_ops # pylint: disable=g-import-not-at-top
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(None, "RowPartitionFromValueRowIds",
|
||||
@ -278,7 +281,7 @@ class RowPartition(composite_tensor.CompositeTensor):
|
||||
# cast.
|
||||
value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
|
||||
nrows_int32 = math_ops.cast(nrows, dtypes.int32)
|
||||
row_lengths = math_ops.bincount(
|
||||
row_lengths = bincount_ops.bincount(
|
||||
value_rowids_int32,
|
||||
minlength=nrows_int32,
|
||||
maxlength=nrows_int32,
|
||||
|
@ -98,6 +98,8 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None,
|
||||
Returns:
|
||||
A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`.
|
||||
"""
|
||||
# Local import bincount_ops to avoid import-cycle.
|
||||
from tensorflow.python.ops import bincount_ops # pylint: disable=g-import-not-at-top
|
||||
if out_type is None:
|
||||
if isinstance(segment_ids, ops.Tensor):
|
||||
out_type = segment_ids.dtype
|
||||
@ -119,7 +121,7 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None,
|
||||
dtype=dtypes.int32)
|
||||
num_segments.shape.assert_has_rank(0)
|
||||
|
||||
row_lengths = math_ops.bincount(
|
||||
row_lengths = bincount_ops.bincount(
|
||||
segment_ids,
|
||||
minlength=num_segments,
|
||||
maxlength=num_segments,
|
||||
|
@ -82,7 +82,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "bincount"
|
||||
argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\", \'None\'], "
|
||||
argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\', \'name\', \'axis\', \'binary_output\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\", \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ceil"
|
||||
|
Loading…
Reference in New Issue
Block a user