Update test with self.evaluate(), as was suggested in review
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
8a4733dab9
commit
ab33e4f733
@ -286,6 +286,9 @@ def _tf_dataset_len(s):
|
|||||||
msg = gen_string_ops.string_join(
|
msg = gen_string_ops.string_join(
|
||||||
["len requires dataset with definitive cardinality, got ",
|
["len requires dataset with definitive cardinality, got ",
|
||||||
gen_string_ops.as_string(l)])
|
gen_string_ops.as_string(l)])
|
||||||
|
# TODO (yongtang): UNKNOWN is treated as an error.
|
||||||
|
# In case there are more UNKNOWN cases for dataset, we could
|
||||||
|
# use dataset.reduce() to find out the length (in an expensive way).
|
||||||
with ops.control_dependencies([control_flow_ops.Assert(
|
with ops.control_dependencies([control_flow_ops.Assert(
|
||||||
math_ops.logical_and(
|
math_ops.logical_and(
|
||||||
math_ops.not_equal(l, cardinality.INFINITE),
|
math_ops.not_equal(l, cardinality.INFINITE),
|
||||||
|
@ -125,19 +125,19 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
|
|
||||||
def test_len_dataset(self):
|
def test_len_dataset(self):
|
||||||
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
|
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
|
||||||
self.assertEqual(py_builtins.len_(dataset), 3)
|
self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3)
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
def test_fn():
|
def test_fn():
|
||||||
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
|
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
|
||||||
return py_builtins.len_(dataset)
|
return py_builtins.len_(dataset)
|
||||||
self.assertEqual(test_fn(), 3)
|
self.assertEqual(self.evaluate(test_fn()), 3)
|
||||||
|
|
||||||
def test_len_dataset_infinite(self):
|
def test_len_dataset_infinite(self):
|
||||||
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
|
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
_ = py_builtins.len_(dataset)
|
_ = self.evaluate(py_builtins.len_(dataset))
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@ -145,12 +145,12 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
|
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
|
||||||
return py_builtins.len_(dataset)
|
return py_builtins.len_(dataset)
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
test_fn()
|
self.evaluate(test_fn())
|
||||||
|
|
||||||
def test_len_dataset_unknown(self):
|
def test_len_dataset_unknown(self):
|
||||||
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
|
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
_ = py_builtins.len_(dataset)
|
_ = self.evaluate(py_builtins.len_(dataset))
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
@ -158,7 +158,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
|
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
|
||||||
return py_builtins.len_(dataset)
|
return py_builtins.len_(dataset)
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
test_fn()
|
self.evaluate(test_fn())
|
||||||
|
|
||||||
def test_len_scalar(self):
|
def test_len_scalar(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
Loading…
Reference in New Issue
Block a user