[tf.data] Moving cardinality to core API.

PiperOrigin-RevId: 309266739
Change-Id: I9f9364a13e01d4cab32ddea06a1ec26f2f5ff307
Jiri Simsa 2020-04-30 11:35:12 -07:00 committed by TensorFlower Gardener
parent 13ad6da577
commit ed5a72cb2a
24 changed files with 312 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "DatasetCardinalityV2"
visibility: HIDDEN
}
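Marking the op HIDDEN keeps DatasetCardinalityV2 out of the public tf.* namespace; it is only reachable through its generated Python wrapper, which the new Dataset.cardinality() method (added later in this commit) calls. A minimal sketch of that internal path, assuming the wrapper lands in gen_experimental_dataset_ops as the Python change below suggests:

# Sketch only: mirrors how the hidden op is reached internally; the
# wrapper module path is taken from the Python change later in this
# commit, not from any public API.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

ds = dataset_ops.Dataset.range(10)
# `_variant_tensor` is the internal variant handle the kernel consumes.
card = ged_ops.dataset_cardinality_v2(ds._variant_tensor)  # float64 scalar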

View File

@@ -115,6 +115,21 @@ void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
result->scalar<int64>()() = dataset->Cardinality();
}
void DatasetCardinalityV2Op::Compute(OpKernelContext* ctx) {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
Tensor* result;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
int64 cardinality = dataset->Cardinality();
if (cardinality == data::kUnknownCardinality) {
result->scalar<double>()() = std::nan("");
} else if (cardinality == data::kInfiniteCardinality) {
result->scalar<double>()() = std::numeric_limits<double>::infinity();
} else {
result->scalar<double>()() = cardinality;
}
}
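The kernel folds the two int64 sentinels into IEEE 754 specials so that a single float64 scalar can express all three cases. A Python sketch of the same mapping; the sentinel values -1 and -2 are assumptions matching kInfiniteCardinality and kUnknownCardinality in the C++ dataset framework, not a public contract:

import math

INFINITE_CARDINALITY = -1  # assumed value of data::kInfiniteCardinality
UNKNOWN_CARDINALITY = -2   # assumed value of data::kUnknownCardinality

def to_float64_cardinality(cardinality):
  # Same branch order as the kernel above: unknown -> NaN,
  # infinite -> inf, anything else -> (approximate) float64.
  if cardinality == UNKNOWN_CARDINALITY:
    return math.nan
  if cardinality == INFINITE_CARDINALITY:
    return math.inf
  return float(cardinality)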
void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
tstring graph_def_string;
OP_REQUIRES_OK(ctx,
@@ -163,6 +178,9 @@ REGISTER_KERNEL_BUILDER(
Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
DatasetCardinalityOp);
REGISTER_KERNEL_BUILDER(Name("DatasetCardinalityV2").Device(DEVICE_CPU),
DatasetCardinalityV2Op);
REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
DatasetFromGraphOp);

View File

@@ -49,6 +49,13 @@ class DatasetCardinalityOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
};
class DatasetCardinalityV2Op : public OpKernel {
public:
explicit DatasetCardinalityV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override;
};
class DatasetFromGraphOp : public OpKernel {
public:
static constexpr const char* const kGraphDef = "graph_def";

View File

@@ -221,6 +221,11 @@ REGISTER_OP("DatasetCardinality")
.Output("cardinality: int64")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("DatasetCardinalityV2")
.Input("input_dataset: variant")
.Output("cardinality: float64")
.SetShapeFn(shape_inference::ScalarShape);
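From Python, the new op should therefore yield a float64 scalar, in contrast to the int64 output of the V1 op above. A quick sanity check, assuming the Dataset.cardinality() method added later in this commit:

import tensorflow as tf

ds = tf.data.Dataset.range(3)
c = ds.cardinality()
assert c.dtype == tf.float64  # matches the `cardinality: float64` output above
assert c.shape.rank == 0      # ScalarShape: a rank-0 tensor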
REGISTER_OP("ExperimentalDatasetCardinality")
.Input("input_dataset: variant")
.Output("cardinality: int64")

View File

@@ -21,6 +21,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -32,6 +33,7 @@ tf_export("data.experimental.UNKNOWN_CARDINALITY").export_constant(
__name__, "UNKNOWN")
@deprecation.deprecated(None, "Use `tf.data.Dataset.cardinality()`.")
@tf_export("data.experimental.cardinality")
def cardinality(dataset):
"""Returns the cardinality of `dataset`, if known.

View File

@@ -46,6 +46,17 @@ tf_py_test(
],
)
tf_py_test(
name = "cardinality_test",
srcs = ["cardinality_test.py"],
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "checkpoint_test",
size = "medium",

View File

@@ -0,0 +1,171 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.cardinality()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
def _test_combinations():
# pylint: disable=g-long-lambda
cases = [
("Batch1",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
("Batch2",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=False), 3),
("Batch3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).batch(2),
np.nan),
("Batch4", lambda: dataset_ops.Dataset.range(5).repeat().batch(2),
np.inf),
("Cache1", lambda: dataset_ops.Dataset.range(5).cache(), 5),
("Cache2", lambda: dataset_ops.Dataset.range(5).cache("foo"), 5),
("Concatenate1", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5)), 10),
("Concatenate2", lambda: dataset_ops.Dataset.range(5).filter(
lambda _: True).concatenate(dataset_ops.Dataset.range(5)), np.nan),
("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
concatenate(dataset_ops.Dataset.range(5)), np.inf),
("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)), np.nan),
("Concatenate5",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)), np.nan),
("Concatenate6",
lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)), np.inf),
("Concatenate7", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5).repeat()), np.inf),
("Concatenate8",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5).repeat()), np.inf),
("Concatenate9", lambda: dataset_ops.Dataset.range(5).repeat().
concatenate(dataset_ops.Dataset.range(5).repeat()), np.inf),
("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
lambda _: dataset_ops.Dataset.from_tensors(0)), np.nan),
("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
np.nan),
("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
("FromTensors2", lambda: dataset_ops.Dataset.from_tensors((0, 1)), 1),
("FromTensorSlices1",
lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0]), 3),
("FromTensorSlices2", lambda: dataset_ops.Dataset.from_tensor_slices(
([0, 0, 0], [1, 1, 1])), 3),
("Interleave1", lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
np.nan),
("Interleave2", lambda: dataset_ops.Dataset.range(5).interleave(
lambda _: dataset_ops.Dataset.from_tensors(0),
cycle_length=1,
num_parallel_calls=1), np.nan),
("Map1", lambda: dataset_ops.Dataset.range(5).map(lambda x: x), 5),
("Map2", lambda: dataset_ops.Dataset.range(5).map(
lambda x: x, num_parallel_calls=1), 5),
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
2, [], drop_remainder=True), 2),
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
2, [], drop_remainder=False), 3),
("PaddedBatch3", lambda: dataset_ops.Dataset.range(5).filter(
lambda _: True).padded_batch(2, []), np.nan),
("PaddedBatch4",
lambda: dataset_ops.Dataset.range(5).repeat().padded_batch(2, []),
np.inf),
("Prefetch", lambda: dataset_ops.Dataset.range(5).prefetch(buffer_size=1),
5),
("Range1", lambda: dataset_ops.Dataset.range(0), 0),
("Range2", lambda: dataset_ops.Dataset.range(5), 5),
("Range3", lambda: dataset_ops.Dataset.range(5, 10), 5),
("Range4", lambda: dataset_ops.Dataset.range(10, 5), 0),
("Range5", lambda: dataset_ops.Dataset.range(5, 10, 2), 3),
("Range6", lambda: dataset_ops.Dataset.range(10, 5, -2), 3),
("Repeat1", lambda: dataset_ops.Dataset.range(0).repeat(0), 0),
("Repeat2", lambda: dataset_ops.Dataset.range(1).repeat(0), 0),
("Repeat3", lambda: dataset_ops.Dataset.range(0).repeat(5), 0),
("Repeat4", lambda: dataset_ops.Dataset.range(1).repeat(5), 5),
("Repeat5", lambda: dataset_ops.Dataset.range(0).repeat(), 0),
("Repeat6", lambda: dataset_ops.Dataset.range(1).repeat(), np.inf),
("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
5),
("Shard1", lambda: dataset_ops.Dataset.range(5).shard(2, 0), 3),
("Shard2", lambda: dataset_ops.Dataset.range(5).shard(8, 7), 0),
("Shard3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).shard(2, 0),
np.nan),
("Shard4", lambda: dataset_ops.Dataset.range(5).repeat().shard(2, 0),
np.inf),
("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
("Skip3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).skip(2),
np.nan),
("Skip4", lambda: dataset_ops.Dataset.range(5).repeat().skip(2), np.inf),
("Take1", lambda: dataset_ops.Dataset.range(5).take(2), 2),
("Take2", lambda: dataset_ops.Dataset.range(5).take(8), 5),
("Take3",
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).take(2),
np.nan),
("Take4", lambda: dataset_ops.Dataset.range(5).repeat().take(2), 2),
("Window1", lambda: dataset_ops.Dataset.range(5).window(
size=2, shift=2, drop_remainder=True), 2),
("Window2", lambda: dataset_ops.Dataset.range(5).window(
size=2, shift=2, drop_remainder=False), 3),
("Zip1", lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5)),
5),
("Zip2", lambda: dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
("Zip3", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5), dataset_ops.Dataset.range(3).repeat())), 5),
("Zip4", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5).repeat(), dataset_ops.Dataset.range(3).repeat())), np.inf),
("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5), dataset_ops.Dataset.range(3).filter(lambda _: True))), np.nan),
]
def reduce_fn(x, y):
name, dataset_fn, expected_result = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn),
expected_result=expected_result)
return functools.reduce(reduce_fn, cases, [])
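The reduce above folds the named cases into one flat list of parameter combinations for the test generator. A stripped-down sketch of the same pattern, using plain dicts instead of the real combinations library:

import functools

cases = [("Range2", lambda: range(5), 5)]

def reduce_fn(x, y):
  name, dataset_fn, expected_result = y
  # Each case contributes one parameter dict, mirroring combinations.combine.
  return x + [{"dataset_fn": (name, dataset_fn),
               "expected_result": expected_result}]

params = functools.reduce(reduce_fn, cases, [])
assert len(params) == len(cases)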
class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.Dataset.cardinality()`."""
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testCardinality(self, dataset_fn, expected_result):
dataset = dataset_fn()
if np.isnan(expected_result):
self.assertTrue(np.isnan(self.evaluate(dataset.cardinality())))
else:
self.assertEqual(self.evaluate(dataset.cardinality()), expected_result)
if __name__ == "__main__":
test.main()

View File

@@ -2080,6 +2080,36 @@ name=None))
"""
return _OptionsDataset(self, options)
def cardinality(self):
"""Returns the (statically known) cardinality of the dataset.
The returned cardinality may be infinite or unknown. the latter will be
returned if static analysis fails to determine the number of elements in
`dataset` (e.g. when the dataset source is a file).
Note: To provide an idiomatic representation for infinite and unknown
cardinality, this method returns a 64-bit floating point number. As a
consequence, the returned cardinality will be approximate for datasets
whose integer cardinality cannot be accurately represented by 64-bit
floating point number (i.e. cardinalities greater than 2^53).
>>> dataset = tf.data.Dataset.range(42)
>>> print(dataset.cardinality().numpy())
42.0
>>> dataset = dataset.repeat()
>>> print(dataset.cardinality().numpy() == np.inf)
True
>>> dataset = dataset.filter(lambda x: True)
>>> print(np.isnan(dataset.cardinality().numpy()))
True
Returns:
A scalar `tf.float64` `Tensor` representing the cardinality of the
dataset. If the cardinality is infinite or unknown, the operation returns
IEEE 754 representation of infinity and NaN respectively.
"""
return ged_ops.dataset_cardinality_v2(self._variant_tensor)
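The 2^53 caveat in the docstring is a property of float64 itself and is easy to demonstrate in plain Python:

# float64 has a 53-bit significand, so integers above 2**53 cannot all be
# represented exactly; 2**53 + 1 rounds down to 2**53.
n = 2**53 + 1
assert float(n) == float(2**53)
assert float(n) != n  # the round-trip loses the +1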
@tf_export(v1=["data.Dataset"])
class DatasetV1(DatasetV2):

View File

@@ -41,6 +41,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -43,6 +43,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -944,6 +944,10 @@ tf_module {
name: "DatasetCardinality"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DatasetCardinalityV2"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DatasetFromGraph"
argspec: "args=[\'graph_def\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -28,6 +28,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -29,6 +29,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -30,6 +30,10 @@ tf_class {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "cardinality"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"

View File

@@ -944,6 +944,10 @@ tf_module {
name: "DatasetCardinality"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DatasetCardinalityV2"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "DatasetFromGraph"
argspec: "args=[\'graph_def\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "