Merge pull request from yongtang:audograph-len-dataset

PiperOrigin-RevId: 308067661
Change-Id: I8fc645d3099e67d287f37308cf298bb4e28d9507
This commit is contained in:
TensorFlower Gardener 2020-04-23 09:29:05 -07:00
commit a4f77b5adc
2 changed files with 63 additions and 0 deletions
tensorflow/python/autograph/operators

View File

@ -29,6 +29,7 @@ import six
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@ -234,6 +235,8 @@ def len_(s):
return _tf_tensor_list_len(s)
elif tensor_util.is_tensor(s):
return _tf_tensor_len(s)
if isinstance(s, dataset_ops.DatasetV2):
return _tf_dataset_len(s)
return _py_len(s)
@ -278,6 +281,26 @@ def _tf_tensor_len(s):
raise_zero_rank_error)
def _tf_dataset_len(s):
l = cardinality.cardinality(s)
msg = gen_string_ops.string_join([
'len requires dataset with definitive cardinality, got ',
gen_string_ops.as_string(l)
])
# TODO (yongtang): UNKNOWN is treated as an error.
# In case there are more UNKNOWN cases for dataset, we could
# use dataset.reduce() to find out the length (in an expensive way).
with ops.control_dependencies([
control_flow_ops.Assert(
math_ops.logical_and(
math_ops.not_equal(l, cardinality.INFINITE),
math_ops.not_equal(l, cardinality.UNKNOWN)), [msg])
]):
l = array_ops.identity(l)
return l
def _py_len(s):
return len(s)

View File

@ -123,6 +123,46 @@ class PyBuiltinsTest(test.TestCase):
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
self.assertEqual(self.evaluate(tl), 3)
def test_len_dataset(self):
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
self.assertEqual(self.evaluate(py_builtins.len_(dataset)), 3)
# graph mode
@def_function.function(autograph=False)
def test_fn():
dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
return py_builtins.len_(dataset)
self.assertEqual(self.evaluate(test_fn()), 3)
def test_len_dataset_infinite(self):
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = self.evaluate(py_builtins.len_(dataset))
# graph mode
@def_function.function
def test_fn():
dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(test_fn())
def test_len_dataset_unknown(self):
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = self.evaluate(py_builtins.len_(dataset))
# graph mode
@def_function.function(autograph=False)
def test_fn():
dataset = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
return py_builtins.len_(dataset)
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(test_fn())
def test_len_scalar(self):
with self.assertRaises(ValueError):
py_builtins.len_(constant_op.constant(1))