Remove forward compatibility check for bincount.
PiperOrigin-RevId: 336704117 Change-Id: I41983da01dfedddcb4a3c5b794e8fccd59422b91
This commit is contained in:
parent
594177934c
commit
68e547e356
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -119,6 +120,26 @@ def bincount(arr,
|
||||
"""
|
||||
name = "bincount" if name is None else name
|
||||
with ops.name_scope(name):
|
||||
# Somehow forward compatible needs to be False.
|
||||
if not binary_output and axis is None:
|
||||
arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
|
||||
array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0
|
||||
output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (
|
||||
math_ops.reduce_max(arr) + 1)
|
||||
if minlength is not None:
|
||||
minlength = ops.convert_to_tensor(
|
||||
minlength, name="minlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.maximum(minlength, output_size)
|
||||
if maxlength is not None:
|
||||
maxlength = ops.convert_to_tensor(
|
||||
maxlength, name="maxlength", dtype=dtypes.int32)
|
||||
output_size = gen_math_ops.minimum(maxlength, output_size)
|
||||
if weights is not None:
|
||||
weights = ops.convert_to_tensor(weights, name="weights")
|
||||
return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
|
||||
weights = constant_op.constant([], dtype)
|
||||
return gen_math_ops.bincount(arr, output_size, weights)
|
||||
|
||||
if not isinstance(arr, sparse_tensor.SparseTensor):
|
||||
arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr")
|
||||
if weights is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user