Merge pull request #38716 from yongtang:audograph-len-dataset
PiperOrigin-RevId: 308067661 Change-Id: I8fc645d3099e67d287f37308cf298bb4e28d9507
This commit is contained in:
commit
a4f77b5adc
tensorflow/python/autograph/operators
@ -29,6 +29,7 @@ import six
|
||||
|
||||
from tensorflow.python.autograph.utils import py_func
|
||||
from tensorflow.python.autograph.utils import tensors
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -234,6 +235,8 @@ def len_(s):
|
||||
return _tf_tensor_list_len(s)
|
||||
elif tensor_util.is_tensor(s):
|
||||
return _tf_tensor_len(s)
|
||||
if isinstance(s, dataset_ops.DatasetV2):
|
||||
return _tf_dataset_len(s)
|
||||
return _py_len(s)
|
||||
|
||||
|
||||
@ -278,6 +281,26 @@ def _tf_tensor_len(s):
|
||||
raise_zero_rank_error)
|
||||
|
||||
|
||||
def _tf_dataset_len(s):
  """Returns the cardinality of dataset `s` as a scalar int64 tensor.

  The result is gated by a runtime assertion: if the dataset's cardinality
  is INFINITE or UNKNOWN, evaluating the returned tensor raises an
  InvalidArgumentError, because `len` is only meaningful for datasets with
  a definitive size.
  """
  length = cardinality.cardinality(s)
  # Build the error message up front; it is only surfaced if the assertion
  # below actually fires at runtime.
  msg = gen_string_ops.string_join([
      'len requires dataset with definitive cardinality, got ',
      gen_string_ops.as_string(length)
  ])
  # TODO (yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  is_infinite = math_ops.equal(length, cardinality.INFINITE)
  is_unknown = math_ops.equal(length, cardinality.UNKNOWN)
  # Definitive length <=> neither infinite nor unknown (De Morgan form of
  # the equivalent not_equal/logical_and check).
  has_definitive_length = math_ops.logical_not(
      math_ops.logical_or(is_infinite, is_unknown))
  guard = control_flow_ops.Assert(has_definitive_length, [msg])
  with ops.control_dependencies([guard]):
    length = array_ops.identity(length)

  return length
|
||||
|
||||
|
||||
def _py_len(s):
|
||||
return len(s)
|
||||
|
||||
|
@ -123,6 +123,46 @@ class PyBuiltinsTest(test.TestCase):
|
||||
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
||||
self.assertEqual(self.evaluate(tl), 3)
|
||||
|
||||
def test_len_dataset(self):
  """len_ reports the number of elements of a finite dataset."""
  # Eager mode: three slices -> length 3.
  ds = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
  self.assertEqual(self.evaluate(py_builtins.len_(ds)), 3)

  # Graph mode: same check inside a tf.function. autograph is disabled so
  # py_builtins.len_ itself is exercised, not an autograph conversion.
  @def_function.function(autograph=False)
  def graph_fn():
    ds = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
    return py_builtins.len_(ds)

  self.assertEqual(self.evaluate(graph_fn()), 3)
|
||||
|
||||
def test_len_dataset_infinite(self):
  """len_ of an infinite (repeated) dataset fails at evaluation time."""
  dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
  with self.assertRaises(errors_impl.InvalidArgumentError):
    _ = self.evaluate(py_builtins.len_(dataset))

  # graph mode
  # Pass autograph=False for consistency with the other dataset-len tests
  # (test_len_dataset, test_len_dataset_unknown), so py_builtins.len_ is
  # exercised directly rather than through autograph conversion.
  @def_function.function(autograph=False)
  def test_fn():
    dataset = dataset_ops.DatasetV2.range(5).repeat().batch(2)
    return py_builtins.len_(dataset)

  with self.assertRaises(errors_impl.InvalidArgumentError):
    self.evaluate(test_fn())
|
||||
|
||||
def test_len_dataset_unknown(self):
  """len_ of a dataset with unknown cardinality (filter) must fail."""
  # A filter makes the cardinality UNKNOWN, which len_ treats as an error.
  ds = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
  with self.assertRaises(errors_impl.InvalidArgumentError):
    _ = self.evaluate(py_builtins.len_(ds))

  # Graph mode: same check inside a tf.function with autograph disabled so
  # py_builtins.len_ is exercised directly.
  @def_function.function(autograph=False)
  def graph_fn():
    ds = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
    return py_builtins.len_(ds)

  with self.assertRaises(errors_impl.InvalidArgumentError):
    self.evaluate(graph_fn())
|
||||
|
||||
def test_len_scalar(self):
  """len_ rejects a rank-0 tensor with a ValueError."""
  scalar = constant_op.constant(1)
  with self.assertRaises(ValueError):
    py_builtins.len_(scalar)
|
||||
|
Loading…
Reference in New Issue
Block a user