[tf.data] Move cardinality to core API.

To migrate, replace `tf.data.experimental.cardinality(dataset)` with `dataset.cardinality()`.
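For example:

dataset = tf.data.Dataset.range(10)

# Before:
cardinality = tf.data.experimental.cardinality(dataset)

# After:
cardinality = dataset.cardinality()

The experimental endpoint remains available for now and is slated for deprecation once users have migrated (b/157691652).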

PiperOrigin-RevId: 314246452
Change-Id: I6839b1519adda790b91b5197ec30bac339ca5149
Andrew Audibert 2020-06-01 18:48:02 -07:00 committed by TensorFlower Gardener
parent 4061712bfe
commit 9aca1f01b9
22 changed files with 396 additions and 0 deletions

View File

@@ -24,8 +24,10 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.data import experimental
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.dataset_ops import INFINITE as INFINITE_CARDINALITY
from tensorflow.python.data.ops.dataset_ops import make_initializable_iterator
from tensorflow.python.data.ops.dataset_ops import make_one_shot_iterator
from tensorflow.python.data.ops.dataset_ops import UNKNOWN as UNKNOWN_CARDINALITY
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
from tensorflow.python.data.ops.readers import TextLineDataset

View File

@@ -32,6 +32,7 @@ tf_export("data.experimental.UNKNOWN_CARDINALITY").export_constant(
__name__, "UNKNOWN")
# TODO(b/157691652): Deprecate this method after migrating users to the new API.
@tf_export("data.experimental.cardinality")
def cardinality(dataset):
"""Returns the cardinality of `dataset`, if known.

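The hunk above is truncated before the function body; after this change the experimental endpoint can be a thin wrapper over the core method, along the lines of this sketch (the actual body is not shown in this view):

@tf_export("data.experimental.cardinality")
def cardinality(dataset):
  """Returns the cardinality of `dataset`, if known."""
  # Sketch: delegate to the new core API.
  return dataset.cardinality()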
View File

@@ -46,6 +46,17 @@ tf_py_test(
],
)
tf_py_test(
name = "cardinality_test",
srcs = ["cardinality_test.py"],
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "checkpoint_test",
size = "medium",
@@ -385,6 +396,19 @@ cuda_py_test(
],
)
tf_py_test(
name = "len_test",
size = "small",
srcs = ["len_test.py"],
deps = [
":test_base",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_combinations",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "list_files_test",
size = "small",

View File

@@ -0,0 +1,174 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.cardinality()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


def _test_combinations():
# pylint: disable=g-long-lambda
cases = [
("Batch1",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
("Batch2",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=False), 3),
("Batch3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).batch(2),
dataset_ops.UNKNOWN),
("Batch4", lambda: dataset_ops.Dataset.range(5).repeat().batch(2),
dataset_ops.INFINITE),
("Cache1", lambda: dataset_ops.Dataset.range(5).cache(), 5),
("Cache2", lambda: dataset_ops.Dataset.range(5).cache("foo"), 5),
("Concatenate1", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5)), 10),
("Concatenate2",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5)), dataset_ops.UNKNOWN),
("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
concatenate(dataset_ops.Dataset.range(5)), dataset_ops.INFINITE),
("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)),
dataset_ops.UNKNOWN),
("Concatenate5",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)),
dataset_ops.UNKNOWN),
("Concatenate6", lambda: dataset_ops.Dataset.range(5).repeat().
concatenate(dataset_ops.Dataset.range(5).filter(lambda _: True)),
dataset_ops.INFINITE),
("Concatenate7", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5).repeat()), dataset_ops.INFINITE),
("Concatenate8",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5).repeat()), dataset_ops.INFINITE),
("Concatenate9",
lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
dataset_ops.Dataset.range(5).repeat()), dataset_ops.INFINITE),
("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
lambda _: dataset_ops.Dataset.from_tensors(0)), dataset_ops.UNKNOWN),
("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
dataset_ops.UNKNOWN),
("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
("FromTensors2", lambda: dataset_ops.Dataset.from_tensors((0, 1)), 1),
("FromTensorSlices1",
lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0]), 3),
("FromTensorSlices2", lambda: dataset_ops.Dataset.from_tensor_slices(
([0, 0, 0], [1, 1, 1])), 3),
("Interleave1", lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
dataset_ops.UNKNOWN),
("Interleave2", lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0),
cycle_length=1,
num_parallel_calls=1), dataset_ops.UNKNOWN),
("Map1", lambda: dataset_ops.Dataset.range(5).map(lambda x: x), 5),
("Map2", lambda: dataset_ops.Dataset.range(5).map(
lambda x: x, num_parallel_calls=1), 5),
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
2, [], drop_remainder=True), 2),
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
2, [], drop_remainder=False), 3),
("PaddedBatch3", lambda: dataset_ops.Dataset.range(5).filter(
lambda _: True).padded_batch(2, []), dataset_ops.UNKNOWN),
("PaddedBatch4",
lambda: dataset_ops.Dataset.range(5).repeat().padded_batch(2, []),
dataset_ops.INFINITE),
("Prefetch", lambda: dataset_ops.Dataset.range(5).prefetch(buffer_size=1),
5),
("Range1", lambda: dataset_ops.Dataset.range(0), 0),
("Range2", lambda: dataset_ops.Dataset.range(5), 5),
("Range3", lambda: dataset_ops.Dataset.range(5, 10), 5),
("Range4", lambda: dataset_ops.Dataset.range(10, 5), 0),
("Range5", lambda: dataset_ops.Dataset.range(5, 10, 2), 3),
("Range6", lambda: dataset_ops.Dataset.range(10, 5, -2), 3),
("Repeat1", lambda: dataset_ops.Dataset.range(0).repeat(0), 0),
("Repeat2", lambda: dataset_ops.Dataset.range(1).repeat(0), 0),
("Repeat3", lambda: dataset_ops.Dataset.range(0).repeat(5), 0),
("Repeat4", lambda: dataset_ops.Dataset.range(1).repeat(5), 5),
("Repeat5", lambda: dataset_ops.Dataset.range(0).repeat(), 0),
("Repeat6", lambda: dataset_ops.Dataset.range(1).repeat(),
dataset_ops.INFINITE),
("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
5),
("Shard1", lambda: dataset_ops.Dataset.range(5).shard(2, 0), 3),
("Shard2", lambda: dataset_ops.Dataset.range(5).shard(8, 7), 0),
("Shard3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).shard(2, 0),
dataset_ops.UNKNOWN),
("Shard4", lambda: dataset_ops.Dataset.range(5).repeat().shard(2, 0),
dataset_ops.INFINITE),
("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
("Skip3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).skip(2),
dataset_ops.UNKNOWN),
("Skip4", lambda: dataset_ops.Dataset.range(5).repeat().skip(2),
dataset_ops.INFINITE),
("Take1", lambda: dataset_ops.Dataset.range(5).take(2), 2),
("Take2", lambda: dataset_ops.Dataset.range(5).take(8), 5),
("Take3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).take(2),
dataset_ops.UNKNOWN),
("Take4", lambda: dataset_ops.Dataset.range(5).repeat().take(2), 2),
("Window1", lambda: dataset_ops.Dataset.range(5).window(
size=2, shift=2, drop_remainder=True), 2),
("Window2", lambda: dataset_ops.Dataset.range(5).window(
size=2, shift=2, drop_remainder=False), 3),
("Zip1", lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5)),
5),
("Zip2", lambda: dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
("Zip3", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5), dataset_ops.Dataset.range(3).repeat())), 5),
("Zip4", lambda: dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5).repeat(), dataset_ops.Dataset.range(3).
repeat())), dataset_ops.INFINITE),
("Zip5", lambda: dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3).filter(
lambda _: True))), dataset_ops.UNKNOWN),
]
def reduce_fn(x, y):
name, dataset_fn, expected_result = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn),
expected_result=expected_result)
return functools.reduce(reduce_fn, cases, [])


class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.Dataset.cardinality()`."""
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testCardinality(self, dataset_fn, expected_result):
dataset = dataset_fn()
self.assertEqual(self.evaluate(dataset.cardinality()), expected_result)


if __name__ == "__main__":
test.main()
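Run eagerly, each named case pairs a dataset constructor with its expected cardinality. The Batch1 case, for instance, reduces to a one-line check: five elements batched in pairs with the remainder dropped yield two batches.

ds = dataset_ops.Dataset.range(5).batch(2, drop_remainder=True)
assert ds.cardinality().numpy() == 2  # batches [0, 1] and [2, 3]; 4 is dropped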

View File

@@ -0,0 +1,59 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.__len__()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test


class LenTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testKnown(self):
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
self.assertLen(ds, num_elements)
@combinations.generate(test_base.eager_only_combinations())
def testInfinite(self):
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements).repeat()
with self.assertRaisesRegex(TypeError, 'infinite'):
len(ds)
@combinations.generate(test_base.eager_only_combinations())
def testUnknown(self):
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements).filter(lambda x: True)
with self.assertRaisesRegex(TypeError, 'unknown'):
len(ds)
@combinations.generate(test_base.graph_only_combinations())
def testGraphMode(self):
num_elements = 10
ds = dataset_ops.Dataset.range(num_elements)
with self.assertRaisesRegex(TypeError, 'not supported while tracing'):
len(ds)


if __name__ == '__main__':
test.main()
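As these tests show, `len()` works only in eager mode on datasets whose cardinality is known and finite; code that may receive unknown or infinite inputs should branch on `cardinality()` instead. A minimal sketch, assuming eager execution (`dataset_length_or_none` is an illustrative helper, not part of this change):

import tensorflow as tf

def dataset_length_or_none(ds):
  # Illustrative helper: returns a Python int when the length is known
  # and finite, otherwise None.
  c = ds.cardinality().numpy()
  if c == tf.data.INFINITE_CARDINALITY or c == tf.data.UNKNOWN_CARDINALITY:
    return None
  return int(c)

assert dataset_length_or_none(tf.data.Dataset.range(10)) == 10
assert dataset_length_or_none(tf.data.Dataset.range(10).repeat()) is None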

View File

@@ -94,6 +94,12 @@ ops.NotDifferentiable("ReduceDataset")
AUTOTUNE = -1
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
# Constants representing infinite and unknown cardinalities.
INFINITE = -1
UNKNOWN = -2
tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
@@ -410,6 +416,36 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
def __bool__(self):
return True # Required as __len__ is defined
__nonzero__ = __bool__ # Python 2 backward compatibility
def __len__(self):
"""Returns the length of the dataset if it is known and finite.
This method requires that you are running in eager mode, and that the
length of the dataset is known and finite. When the length may be
unknown or infinite, or if you are running in graph mode, use
`tf.data.Dataset.cardinality` instead.
Returns:
An integer representing the length of the dataset.
Raises:
TypeError: If the dataset length is unknown or infinite, or if eager
execution is not enabled.
"""
if not context.executing_eagerly():
raise TypeError("__len__() is not supported while tracing functions. "
"Use `tf.data.Dataset.cardinality` instead.")
length = self.cardinality()
if length.numpy() == INFINITE:
raise TypeError("dataset length is infinite.")
if length.numpy() == UNKNOWN:
raise TypeError("dataset length is unknown.")
return length
@abc.abstractproperty
def element_spec(self):
"""The type specification of an element of this dataset.
@@ -2095,6 +2131,34 @@ name=None))
"""
return _OptionsDataset(self, options)
def cardinality(self):
"""Returns the cardinality of the dataset, if known.
`cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
the analysis fails to determine the number of elements in the dataset
(e.g. when the dataset source is a file).
>>> dataset = tf.data.Dataset.range(42)
>>> print(dataset.cardinality().numpy())
42
>>> dataset = dataset.repeat()
>>> cardinality = dataset.cardinality()
>>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
>>> dataset = dataset.filter(lambda x: True)
>>> cardinality = dataset.cardinality()
>>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True
Returns:
A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
If the cardinality is infinite or unknown, `cardinality` returns the
named constants `tf.data.INFINITE_CARDINALITY` and
`tf.data.UNKNOWN_CARDINALITY` respectively.
"""
return ged_ops.dataset_cardinality(self._variant_tensor)
@tf_export(v1=["data.Dataset"])
class DatasetV1(DatasetV2):

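Because `cardinality()` bottoms out in a regular op on the dataset's variant tensor, it is also usable while tracing, where `__len__` raises; a small sketch, assuming `import tensorflow as tf`:

@tf.function
def log_cardinality(ds):
  # Returns a symbolic tf.int64 scalar; len(ds) here would raise TypeError.
  return ds.cardinality()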
View File

@@ -41,6 +41,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -12,6 +12,10 @@ tf_module {
name: "FixedLengthRecordDataset"
mtype: "<type \'type\'>"
}
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "Iterator"
mtype: "<type \'type\'>"
@@ -28,6 +32,10 @@ tf_module {
name: "TextLineDataset"
mtype: "<type \'type\'>"
}
member {
name: "UNKNOWN_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"

View File

@@ -28,6 +28,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -29,6 +29,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -12,6 +12,10 @@ tf_module {
name: "FixedLengthRecordDataset"
mtype: "<type \'type\'>"
}
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "Options"
mtype: "<type \'type\'>"
@@ -24,6 +28,10 @@ tf_module {
name: "TextLineDataset"
mtype: "<type \'type\'>"
}
member {
name: "UNKNOWN_CARDINALITY"
mtype: "<type \'int\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"