Update test with self.evaluate(), as was suggested in review

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-21 23:17:39 +00:00
parent 8a4733dab9
commit ab33e4f733
2 changed files with 9 additions and 6 deletions

View File

@ -286,6 +286,9 @@ def _tf_dataset_len(s):
msg = gen_string_ops.string_join(
["len requires dataset with definitive cardinality, got ",
gen_string_ops.as_string(l)])
# TODO (yongtang): UNKNOWN is treated as an error.
# In case there are more UNKNOWN cases for dataset, we could
# use dataset.reduce() to find out the length (in an expensive way).
with ops.control_dependencies([control_flow_ops.Assert(
math_ops.logical_and(
math_ops.not_equal(l, cardinality.INFINITE),

View File

@ -125,19 +125,19 @@ class PyBuiltinsTest(test.TestCase):
def test_len_dataset(self):
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
self.assertEqual(py_builtins.len_(dataset), 3)
self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3)
# graph mode
@def_function.function(autograph=False)
def test_fn():
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
return py_builtins.len_(dataset)
self.assertEqual(test_fn(), 3)
self.assertEqual(self.evaluate(test_fn()), 3)
def test_len_dataset_infinite(self):
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = py_builtins.len_(dataset)
_ = self.evaluate(py_builtins.len_(dataset))
# graph mode
@def_function.function
@ -145,12 +145,12 @@ class PyBuiltinsTest(test.TestCase):
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError):
test_fn()
self.evaluate(test_fn())
def test_len_dataset_unknown(self):
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = py_builtins.len_(dataset)
_ = self.evaluate(py_builtins.len_(dataset))
# graph mode
@def_function.function(autograph=False)
@ -158,7 +158,7 @@ class PyBuiltinsTest(test.TestCase):
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError):
test_fn()
self.evaluate(test_fn())
def test_len_scalar(self):
with self.assertRaises(ValueError):