Update test with self.evaluate(), as was suggested in review

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-21 23:17:39 +00:00
parent 8a4733dab9
commit ab33e4f733
2 changed files with 9 additions and 6 deletions

View File

@ -286,6 +286,9 @@ def _tf_dataset_len(s):
msg = gen_string_ops.string_join( msg = gen_string_ops.string_join(
["len requires dataset with definitive cardinality, got ", ["len requires dataset with definitive cardinality, got ",
gen_string_ops.as_string(l)]) gen_string_ops.as_string(l)])
# TODO (yongtang): UNKNOWN is treated as an error.
# In case there are more UNKNOWN cases for dataset, we could
# use dataset.reduce() to find out the length (in an expensive way).
with ops.control_dependencies([control_flow_ops.Assert( with ops.control_dependencies([control_flow_ops.Assert(
math_ops.logical_and( math_ops.logical_and(
math_ops.not_equal(l, cardinality.INFINITE), math_ops.not_equal(l, cardinality.INFINITE),

View File

@ -125,19 +125,19 @@ class PyBuiltinsTest(test.TestCase):
def test_len_dataset(self): def test_len_dataset(self):
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
self.assertEqual(py_builtins.len_(dataset), 3) self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3)
# graph mode # graph mode
@def_function.function(autograph=False) @def_function.function(autograph=False)
def test_fn(): def test_fn():
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
return py_builtins.len_(dataset) return py_builtins.len_(dataset)
self.assertEqual(test_fn(), 3) self.assertEqual(self.evaluate(test_fn()), 3)
def test_len_dataset_infinite(self): def test_len_dataset_infinite(self):
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
_ = py_builtins.len_(dataset) _ = self.evaluate(py_builtins.len_(dataset))
# graph mode # graph mode
@def_function.function @def_function.function
@ -145,12 +145,12 @@ class PyBuiltinsTest(test.TestCase):
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
return py_builtins.len_(dataset) return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
test_fn() self.evaluate(test_fn())
def test_len_dataset_unknown(self): def test_len_dataset_unknown(self):
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
_ = py_builtins.len_(dataset) _ = self.evaluate(py_builtins.len_(dataset))
# graph mode # graph mode
@def_function.function(autograph=False) @def_function.function(autograph=False)
@ -158,7 +158,7 @@ class PyBuiltinsTest(test.TestCase):
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
return py_builtins.len_(dataset) return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
test_fn() self.evaluate(test_fn())
def test_len_scalar(self): def test_len_scalar(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):