Merge pull request #38899 from yongtang:29661-histogram_fixed_width_bins-exception

PiperOrigin-RevId: 310151116
Change-Id: If6f1eebf7517fd48a85dd6887154850422d111e5
This commit is contained in:
TensorFlower Gardener 2020-05-06 08:04:54 -07:00
commit 5c9f025c55
2 changed files with 20 additions and 0 deletions

View File

@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops
@ -76,11 +77,20 @@ def histogram_fixed_width_bins(values,
"""
with ops.name_scope(name, 'histogram_fixed_width_bins',
[values, value_range, nbins]):
value_range_value = tensor_util.constant_value(value_range)
if value_range_value is not None:
if (value_range_value[0] >= value_range_value[1]):
raise ValueError(
'value_range should satisfy value_range[0] < value_range[1], ',
"but got '[{}, {}]".format(value_range_value[0],
value_range_value[1]))
values = ops.convert_to_tensor(values, name='values')
shape = array_ops.shape(values)
values = array_ops.reshape(values, [-1])
value_range = ops.convert_to_tensor(value_range, name='value_range')
nbins = ops.convert_to_tensor(nbins, dtype=dtypes.int32, name='nbins')
nbins_float = math_ops.cast(nbins, values.dtype)

View File

@ -79,6 +79,16 @@ class BinValuesFixedWidth(test.TestCase):
self.assertEqual(dtypes.int32, bins.dtype)
self.assertAllClose(expected_bins, self.evaluate(bins))
def test_range_overlap(self):
# GitHub issue 29661
value_range = np.float32([0.0, 0.0])
values = np.float32([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
expected_bins = [0, 0, 4, 4, 4, 4]
with self.assertRaises(ValueError):
with self.cached_session():
_ = histogram_ops.histogram_fixed_width_bins(
values, value_range, nbins=5)
class HistogramFixedWidthTest(test.TestCase):