Use eager-friendly evaluation for bincount_op_test.
PiperOrigin-RevId: 236690668
This commit is contained in:
parent
6163623e0e
commit
70438aaa2b
@ -30,44 +30,48 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class BincountTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_empty(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(
|
||||
math_ops.bincount([], minlength=5).eval(), [0, 0, 0, 0, 0])
|
||||
self.assertAllEqual(math_ops.bincount([], minlength=1).eval(), [0])
|
||||
self.assertAllEqual(math_ops.bincount([], minlength=0).eval(), [])
|
||||
self.assertEqual(
|
||||
math_ops.bincount([], minlength=0, dtype=np.float32).eval().dtype,
|
||||
np.float32)
|
||||
self.assertEqual(
|
||||
math_ops.bincount([], minlength=3, dtype=np.float64).eval().dtype,
|
||||
np.float64)
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=5)),
|
||||
[0, 0, 0, 0, 0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=1)),
|
||||
[0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=0)),
|
||||
[])
|
||||
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=0,
|
||||
dtype=np.float32)).dtype,
|
||||
np.float32)
|
||||
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=3,
|
||||
dtype=np.float64)).dtype,
|
||||
np.float64)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_values(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(
|
||||
math_ops.bincount([1, 1, 1, 2, 2, 3]).eval(), [0, 3, 2, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([1, 1, 1, 2, 2, 3])),
|
||||
[0, 3, 2, 1])
|
||||
arr = [1, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
|
||||
self.assertAllEqual(math_ops.bincount(arr).eval(), [0, 5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||
[0, 5, 4, 3, 2, 1])
|
||||
arr += [0, 0, 0, 0, 0, 0]
|
||||
self.assertAllEqual(math_ops.bincount(arr).eval(), [6, 5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||
[6, 5, 4, 3, 2, 1])
|
||||
|
||||
self.assertAllEqual(math_ops.bincount([]).eval(), [])
|
||||
self.assertAllEqual(math_ops.bincount([0, 0, 0]).eval(), [3])
|
||||
self.assertAllEqual(math_ops.bincount([5]).eval(), [0, 0, 0, 0, 0, 1])
|
||||
self.assertAllEqual(
|
||||
math_ops.bincount(np.arange(10000)).eval(), np.ones(10000))
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([])), [])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([0, 0, 0])), [3])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([5])),
|
||||
[0, 0, 0, 0, 0, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount(np.arange(10000))),
|
||||
np.ones(10000))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_maxlength(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(math_ops.bincount([5], maxlength=3).eval(), [0, 0, 0])
|
||||
self.assertAllEqual(math_ops.bincount([1], maxlength=3).eval(), [0, 1])
|
||||
self.assertAllEqual(math_ops.bincount([], maxlength=3).eval(), [])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([5], maxlength=3)),
|
||||
[0, 0, 0])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([1], maxlength=3)),
|
||||
[0, 1])
|
||||
self.assertAllEqual(self.evaluate(math_ops.bincount([], maxlength=3)),
|
||||
[])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_random_with_weights(self):
|
||||
num_samples = 10000
|
||||
with self.session(use_gpu=True):
|
||||
@ -79,9 +83,9 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
else:
|
||||
weights = np.random.random(num_samples)
|
||||
self.assertAllClose(
|
||||
math_ops.bincount(arr, weights).eval(), np.bincount(arr, weights))
|
||||
self.evaluate(math_ops.bincount(arr, weights)),
|
||||
np.bincount(arr, weights))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_random_without_weights(self):
|
||||
num_samples = 10000
|
||||
with self.session(use_gpu=True):
|
||||
@ -90,20 +94,20 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
arr = np.random.randint(0, 1000, num_samples)
|
||||
weights = np.ones(num_samples).astype(dtype)
|
||||
self.assertAllClose(
|
||||
math_ops.bincount(arr, None).eval(), np.bincount(arr, weights))
|
||||
self.evaluate(math_ops.bincount(arr, None)),
|
||||
np.bincount(arr, weights))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_zero_weights(self):
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(
|
||||
math_ops.bincount(np.arange(1000), np.zeros(1000)).eval(),
|
||||
self.evaluate(math_ops.bincount(np.arange(1000), np.zeros(1000))),
|
||||
np.zeros(1000))
|
||||
|
||||
def test_negative(self):
|
||||
# unsorted_segment_sum will only report InvalidArgumentError on CPU
|
||||
with self.cached_session():
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
|
||||
self.evaluate(math_ops.bincount([1, 2, 3, -1, 6, 8]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_shape_function(self):
|
||||
|
Loading…
Reference in New Issue
Block a user