Remove forward compatibility check for bincount.
PiperOrigin-RevId: 336704117 Change-Id: I41983da01dfedddcb4a3c5b794e8fccd59422b91
This commit is contained in:
parent
594177934c
commit
68e547e356
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -119,6 +120,26 @@ def bincount(arr,
|
|||||||
"""
|
"""
|
||||||
name = "bincount" if name is None else name
|
name = "bincount" if name is None else name
|
||||||
with ops.name_scope(name):
|
with ops.name_scope(name):
|
||||||
|
# Somehow forward compatible needs to be False.
|
||||||
|
if not binary_output and axis is None:
|
||||||
|
arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
|
||||||
|
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0
|
||||||
|
output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (
|
||||||
|
math_ops.reduce_max(arr) + 1)
|
||||||
|
if minlength is not None:
|
||||||
|
minlength = ops.convert_to_tensor(
|
||||||
|
minlength, name="minlength", dtype=dtypes.int32)
|
||||||
|
output_size = gen_math_ops.maximum(minlength, output_size)
|
||||||
|
if maxlength is not None:
|
||||||
|
maxlength = ops.convert_to_tensor(
|
||||||
|
maxlength, name="maxlength", dtype=dtypes.int32)
|
||||||
|
output_size = gen_math_ops.minimum(maxlength, output_size)
|
||||||
|
if weights is not None:
|
||||||
|
weights = ops.convert_to_tensor(weights, name="weights")
|
||||||
|
return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
|
||||||
|
weights = constant_op.constant([], dtype)
|
||||||
|
return gen_math_ops.bincount(arr, output_size, weights)
|
||||||
|
|
||||||
if not isinstance(arr, sparse_tensor.SparseTensor):
|
if not isinstance(arr, sparse_tensor.SparseTensor):
|
||||||
arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr")
|
arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr")
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user