Merge pull request #38899 from yongtang:29661-histogram_fixed_width_bins-exception
PiperOrigin-RevId: 310151116 Change-Id: If6f1eebf7517fd48a85dd6887154850422d111e5
This commit is contained in:
commit
5c9f025c55
@ -22,6 +22,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import clip_ops
|
from tensorflow.python.ops import clip_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
@ -76,11 +77,20 @@ def histogram_fixed_width_bins(values,
|
|||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'histogram_fixed_width_bins',
|
with ops.name_scope(name, 'histogram_fixed_width_bins',
|
||||||
[values, value_range, nbins]):
|
[values, value_range, nbins]):
|
||||||
|
value_range_value = tensor_util.constant_value(value_range)
|
||||||
|
if value_range_value is not None:
|
||||||
|
if (value_range_value[0] >= value_range_value[1]):
|
||||||
|
raise ValueError(
|
||||||
|
'value_range should satisfy value_range[0] < value_range[1], ',
|
||||||
|
"but got '[{}, {}]".format(value_range_value[0],
|
||||||
|
value_range_value[1]))
|
||||||
|
|
||||||
values = ops.convert_to_tensor(values, name='values')
|
values = ops.convert_to_tensor(values, name='values')
|
||||||
shape = array_ops.shape(values)
|
shape = array_ops.shape(values)
|
||||||
|
|
||||||
values = array_ops.reshape(values, [-1])
|
values = array_ops.reshape(values, [-1])
|
||||||
value_range = ops.convert_to_tensor(value_range, name='value_range')
|
value_range = ops.convert_to_tensor(value_range, name='value_range')
|
||||||
|
|
||||||
nbins = ops.convert_to_tensor(nbins, dtype=dtypes.int32, name='nbins')
|
nbins = ops.convert_to_tensor(nbins, dtype=dtypes.int32, name='nbins')
|
||||||
nbins_float = math_ops.cast(nbins, values.dtype)
|
nbins_float = math_ops.cast(nbins, values.dtype)
|
||||||
|
|
||||||
|
@ -79,6 +79,16 @@ class BinValuesFixedWidth(test.TestCase):
|
|||||||
self.assertEqual(dtypes.int32, bins.dtype)
|
self.assertEqual(dtypes.int32, bins.dtype)
|
||||||
self.assertAllClose(expected_bins, self.evaluate(bins))
|
self.assertAllClose(expected_bins, self.evaluate(bins))
|
||||||
|
|
||||||
|
def test_range_overlap(self):
|
||||||
|
# GitHub issue 29661
|
||||||
|
value_range = np.float32([0.0, 0.0])
|
||||||
|
values = np.float32([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
|
||||||
|
expected_bins = [0, 0, 4, 4, 4, 4]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
with self.cached_session():
|
||||||
|
_ = histogram_ops.histogram_fixed_width_bins(
|
||||||
|
values, value_range, nbins=5)
|
||||||
|
|
||||||
|
|
||||||
class HistogramFixedWidthTest(test.TestCase):
|
class HistogramFixedWidthTest(test.TestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user