diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index 439d7d0c285..152a69013cd 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -286,6 +286,9 @@ def _tf_dataset_len(s): msg = gen_string_ops.string_join( ["len requires dataset with definitive cardinality, got ", gen_string_ops.as_string(l)]) + # TODO (yongtang): UNKNOWN is treated as an error. + # In case there are more UNKNOWN cases for dataset, we could + # use dataset.reduce() to find out the length (in an expensive way). with ops.control_dependencies([control_flow_ops.Assert( math_ops.logical_and( math_ops.not_equal(l, cardinality.INFINITE), diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py index bfa185b52a2..83b618809db 100644 --- a/tensorflow/python/autograph/operators/py_builtins_test.py +++ b/tensorflow/python/autograph/operators/py_builtins_test.py @@ -125,19 +125,19 @@ class PyBuiltinsTest(test.TestCase): def test_len_dataset(self): dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) - self.assertEqual(py_builtins.len_(dataset), 3) + self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3) # graph mode @def_function.function(autograph=False) def test_fn(): dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1]) return py_builtins.len_(dataset) - self.assertEqual(test_fn(), 3) + self.assertEqual(self.evaluate(test_fn()), 3) def test_len_dataset_infinite(self): dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) with self.assertRaises(errors_impl.InvalidArgumentError): - _ = py_builtins.len_(dataset) + _ = self.evaluate(py_builtins.len_(dataset)) # graph mode @def_function.function @@ -145,12 +145,12 @@ class PyBuiltinsTest(test.TestCase): dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2) return py_builtins.len_(dataset) with self.assertRaises(errors_impl.InvalidArgumentError): - test_fn() + self.evaluate(test_fn()) def test_len_dataset_unknown(self): dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) with self.assertRaises(errors_impl.InvalidArgumentError): - _ = py_builtins.len_(dataset) + _ = self.evaluate(py_builtins.len_(dataset)) # graph mode @def_function.function(autograph=False) @@ -158,7 +158,7 @@ class PyBuiltinsTest(test.TestCase): dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2) return py_builtins.len_(dataset) with self.assertRaises(errors_impl.InvalidArgumentError): - test_fn() + self.evaluate(test_fn()) def test_len_scalar(self): with self.assertRaises(ValueError):