Use eager-friendly evaluation for bincount_op_test.
PiperOrigin-RevId: 236690668
This commit is contained in:
parent
6163623e0e
commit
70438aaa2b
@ -30,44 +30,48 @@ from tensorflow.python.platform import googletest
|
|||||||
|
|
||||||
class BincountTest(test_util.TensorFlowTestCase):
|
class BincountTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=5)),
|
||||||
math_ops.bincount([], minlength=5).eval(), [0, 0, 0, 0, 0])
|
[0, 0, 0, 0, 0])
|
||||||
self.assertAllEqual(math_ops.bincount([], minlength=1).eval(), [0])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=1)),
|
||||||
self.assertAllEqual(math_ops.bincount([], minlength=0).eval(), [])
|
[0])
|
||||||
self.assertEqual(
|
self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=0)),
|
||||||
math_ops.bincount([], minlength=0, dtype=np.float32).eval().dtype,
|
[])
|
||||||
|
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=0,
|
||||||
|
dtype=np.float32)).dtype,
|
||||||
np.float32)
|
np.float32)
|
||||||
self.assertEqual(
|
self.assertEqual(self.evaluate(math_ops.bincount([], minlength=3,
|
||||||
math_ops.bincount([], minlength=3, dtype=np.float64).eval().dtype,
|
dtype=np.float64)).dtype,
|
||||||
np.float64)
|
np.float64)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_values(self):
|
def test_values(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(self.evaluate(math_ops.bincount([1, 1, 1, 2, 2, 3])),
|
||||||
math_ops.bincount([1, 1, 1, 2, 2, 3]).eval(), [0, 3, 2, 1])
|
[0, 3, 2, 1])
|
||||||
arr = [1, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
|
arr = [1, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
|
||||||
self.assertAllEqual(math_ops.bincount(arr).eval(), [0, 5, 4, 3, 2, 1])
|
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||||
|
[0, 5, 4, 3, 2, 1])
|
||||||
arr += [0, 0, 0, 0, 0, 0]
|
arr += [0, 0, 0, 0, 0, 0]
|
||||||
self.assertAllEqual(math_ops.bincount(arr).eval(), [6, 5, 4, 3, 2, 1])
|
self.assertAllEqual(self.evaluate(math_ops.bincount(arr)),
|
||||||
|
[6, 5, 4, 3, 2, 1])
|
||||||
|
|
||||||
self.assertAllEqual(math_ops.bincount([]).eval(), [])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([])), [])
|
||||||
self.assertAllEqual(math_ops.bincount([0, 0, 0]).eval(), [3])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([0, 0, 0])), [3])
|
||||||
self.assertAllEqual(math_ops.bincount([5]).eval(), [0, 0, 0, 0, 0, 1])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([5])),
|
||||||
self.assertAllEqual(
|
[0, 0, 0, 0, 0, 1])
|
||||||
math_ops.bincount(np.arange(10000)).eval(), np.ones(10000))
|
self.assertAllEqual(self.evaluate(math_ops.bincount(np.arange(10000))),
|
||||||
|
np.ones(10000))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_maxlength(self):
|
def test_maxlength(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
self.assertAllEqual(math_ops.bincount([5], maxlength=3).eval(), [0, 0, 0])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([5], maxlength=3)),
|
||||||
self.assertAllEqual(math_ops.bincount([1], maxlength=3).eval(), [0, 1])
|
[0, 0, 0])
|
||||||
self.assertAllEqual(math_ops.bincount([], maxlength=3).eval(), [])
|
self.assertAllEqual(self.evaluate(math_ops.bincount([1], maxlength=3)),
|
||||||
|
[0, 1])
|
||||||
|
self.assertAllEqual(self.evaluate(math_ops.bincount([], maxlength=3)),
|
||||||
|
[])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_random_with_weights(self):
|
def test_random_with_weights(self):
|
||||||
num_samples = 10000
|
num_samples = 10000
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
@ -79,9 +83,9 @@ class BincountTest(test_util.TensorFlowTestCase):
|
|||||||
else:
|
else:
|
||||||
weights = np.random.random(num_samples)
|
weights = np.random.random(num_samples)
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
math_ops.bincount(arr, weights).eval(), np.bincount(arr, weights))
|
self.evaluate(math_ops.bincount(arr, weights)),
|
||||||
|
np.bincount(arr, weights))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_random_without_weights(self):
|
def test_random_without_weights(self):
|
||||||
num_samples = 10000
|
num_samples = 10000
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
@ -90,20 +94,20 @@ class BincountTest(test_util.TensorFlowTestCase):
|
|||||||
arr = np.random.randint(0, 1000, num_samples)
|
arr = np.random.randint(0, 1000, num_samples)
|
||||||
weights = np.ones(num_samples).astype(dtype)
|
weights = np.ones(num_samples).astype(dtype)
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
math_ops.bincount(arr, None).eval(), np.bincount(arr, weights))
|
self.evaluate(math_ops.bincount(arr, None)),
|
||||||
|
np.bincount(arr, weights))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_zero_weights(self):
|
def test_zero_weights(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
math_ops.bincount(np.arange(1000), np.zeros(1000)).eval(),
|
self.evaluate(math_ops.bincount(np.arange(1000), np.zeros(1000))),
|
||||||
np.zeros(1000))
|
np.zeros(1000))
|
||||||
|
|
||||||
def test_negative(self):
|
def test_negative(self):
|
||||||
# unsorted_segment_sum will only report InvalidArgumentError on CPU
|
# unsorted_segment_sum will only report InvalidArgumentError on CPU
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
|
self.evaluate(math_ops.bincount([1, 2, 3, -1, 6, 8]))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_shape_function(self):
|
def test_shape_function(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user