diff --git a/tensorflow/python/ops/bincount_ops.py b/tensorflow/python/ops/bincount_ops.py index bda86bf7461..758f0180a84 100644 --- a/tensorflow/python/ops/bincount_ops.py +++ b/tensorflow/python/ops/bincount_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -119,6 +120,26 @@ def bincount(arr, """ name = "bincount" if name is None else name with ops.name_scope(name): + # Somehow forward compatible needs to be False. + if not binary_output and axis is None: + arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32) + array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0 + output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * ( + math_ops.reduce_max(arr) + 1) + if minlength is not None: + minlength = ops.convert_to_tensor( + minlength, name="minlength", dtype=dtypes.int32) + output_size = gen_math_ops.maximum(minlength, output_size) + if maxlength is not None: + maxlength = ops.convert_to_tensor( + maxlength, name="maxlength", dtype=dtypes.int32) + output_size = gen_math_ops.minimum(maxlength, output_size) + if weights is not None: + weights = ops.convert_to_tensor(weights, name="weights") + return gen_math_ops.unsorted_segment_sum(weights, arr, output_size) + weights = constant_op.constant([], dtype) + return gen_math_ops.bincount(arr, output_size, weights) + if not isinstance(arr, sparse_tensor.SparseTensor): arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr") if weights is not None: