Remove forward compatibility check for bincount.

PiperOrigin-RevId: 336704117
Change-Id: I41983da01dfedddcb4a3c5b794e8fccd59422b91
This commit is contained in:
A. Unique TensorFlower 2020-10-12 11:19:52 -07:00 committed by TensorFlower Gardener
parent 594177934c
commit 68e547e356

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -119,6 +120,26 @@ def bincount(arr,
"""
name = "bincount" if name is None else name
with ops.name_scope(name):
# Somehow forward compatible needs to be False.
if not binary_output and axis is None:
arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0
output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (
math_ops.reduce_max(arr) + 1)
if minlength is not None:
minlength = ops.convert_to_tensor(
minlength, name="minlength", dtype=dtypes.int32)
output_size = gen_math_ops.maximum(minlength, output_size)
if maxlength is not None:
maxlength = ops.convert_to_tensor(
maxlength, name="maxlength", dtype=dtypes.int32)
output_size = gen_math_ops.minimum(maxlength, output_size)
if weights is not None:
weights = ops.convert_to_tensor(weights, name="weights")
return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
weights = constant_op.constant([], dtype)
return gen_math_ops.bincount(arr, output_size, weights)
if not isinstance(arr, sparse_tensor.SparseTensor):
arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr")
if weights is not None: