[tf.data] Introducing API for statically asserting the cardinality of a dataset. This API improves user experience when using tf.data with Keras in situations where the cardinality of a file-based dataset is known and can thus be asserted. Having a dataset with a statically known cardinality allows Keras to display accurate progress bar on the first epoch.

PiperOrigin-RevId: 295268926
Change-Id: If15704bee651fbefa385b1bc2d2024633d008bcf
This commit is contained in:
Jiri Simsa 2020-02-14 18:24:47 -08:00 committed by TensorFlower Gardener
parent 0148382b20
commit d25235b196
17 changed files with 494 additions and 5 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "AssertCardinalityDataset"
visibility: HIDDEN
}

View File

@ -14,6 +14,17 @@ package(
exports_files(["LICENSE"])
tf_kernel_library(
name = "assert_cardinality_dataset_op",
srcs = ["assert_cardinality_dataset_op.cc"],
hdrs = ["assert_cardinality_dataset_op.h"],
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core/kernels/data:name_utils",
],
)
tf_kernel_library(
name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"],
@ -22,7 +33,6 @@ tf_kernel_library(
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core/kernels/data:name_utils",
"//third_party/eigen3",
],
)
@ -557,6 +567,7 @@ tf_cc_test(
tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_cardinality_dataset_op",
":assert_next_dataset_op",
":auto_shard_dataset_op",
":choose_fastest_branch_dataset_op",

View File

@ -0,0 +1,180 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h"
#include <map>
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
namespace tensorflow {
namespace data {
namespace experimental {
/* static */ constexpr const char* const
AssertCardinalityDatasetOp::kInputDataset;
/* static */ constexpr const char* const
AssertCardinalityDatasetOp::kDatasetType;
/* static */ constexpr const char* const
AssertCardinalityDatasetOp::kCardinality;
/* static */ constexpr const char* const
AssertCardinalityDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
AssertCardinalityDatasetOp::kOutputShapes;
class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 cardinality,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
cardinality_(cardinality),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}
int64 Cardinality() const override { return cardinality_; }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* cardinality_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(cardinality_, &cardinality_node));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, cardinality_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), num_elements_(0) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
num_elements_++;
}
if (*end_of_sequence && num_elements_ != dataset()->cardinality_) {
return errors::FailedPrecondition(
"Input dataset was expected to contain ",
ElementString(dataset()->cardinality_), " but contained only ",
ElementString(num_elements_), ".");
}
if (num_elements_ > dataset()->cardinality_) {
return errors::FailedPrecondition(
"Input dataset was expected to contain ",
ElementString(dataset()->cardinality_), " but contained at least ",
ElementString(num_elements_), ".");
}
return Status::OK();
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("num_elements"), num_elements_));
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("num_elements"), &num_elements_));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
static string ElementString(int64 n) {
return strings::StrCat(n, " element", n != 1 ? "s" : "");
}
std::unique_ptr<IteratorBase> input_impl_;
int64 num_elements_;
};
const DatasetBase* input_;
const int64 cardinality_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
AssertCardinalityDatasetOp::AssertCardinalityDatasetOp(
OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
void AssertCardinalityDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase* input,
DatasetBase** output) {
int64 cardinality;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<int64>(ctx, kCardinality, &cardinality));
*output = new Dataset(ctx, input, cardinality, output_types_, output_shapes_);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("AssertCardinalityDataset").Device(DEVICE_CPU),
AssertCardinalityDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
namespace experimental {
class AssertCardinalityDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "AssertCardinality";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kCardinality = "cardinality";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
explicit AssertCardinalityDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;
private:
class Dataset;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_

View File

@ -17,6 +17,19 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("AssertCardinalityDataset")
.Input("input_dataset: variant")
.Input("cardinality: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// cardinality should be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return shape_inference::ScalarShape(c);
});
REGISTER_OP("AssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")

View File

@ -45,6 +45,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@TensorStructure
@@ThreadingOptions
@@assert_cardinality
@@bucket_by_sequence_length
@@bytes_produced_stats
@@cardinality
@ -93,6 +94,7 @@ from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_bat
from tensorflow.python.data.experimental.ops.batching import map_and_batch
from tensorflow.python.data.experimental.ops.batching import map_and_batch_with_legacy_function
from tensorflow.python.data.experimental.ops.batching import unbatch
from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.data.experimental.ops.cardinality import cardinality
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY

View File

@ -8,6 +8,17 @@ package(
exports_files(["LICENSE"])
tf_py_test(
name = "assert_cardinality_test",
srcs = ["assert_cardinality_test.py"],
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "assert_next_test",
size = "small",

View File

@ -0,0 +1,79 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.assert_cardinality()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.experimental.assert_cardinality()`."""
@combinations.generate(test_base.default_test_combinations())
def testCorrectCardinality(self):
dataset = dataset_ops.Dataset.range(10).filter(lambda x: True)
self.assertEqual(
self.evaluate(cardinality.cardinality(dataset)), cardinality.UNKNOWN)
self.assertDatasetProduces(dataset, expected_output=range(10))
dataset = dataset.apply(cardinality.assert_cardinality(10))
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
self.assertDatasetProduces(dataset, expected_output=range(10))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
num_elements=10,
asserted_cardinality=20,
expected_error="Input dataset was expected to contain 20 "
"elements but contained only 10 elements.") +
combinations.combine(
num_elements=1,
asserted_cardinality=20,
expected_error="Input dataset was expected to contain 20 "
"elements but contained only 1 element.") +
combinations.combine(
num_elements=10,
asserted_cardinality=5,
expected_error="Input dataset was expected to contain 5 "
"elements but contained at least 6 elements.") +
combinations.combine(
num_elements=10,
asserted_cardinality=1,
expected_error="Input dataset was expected to contain 1 "
"element but contained at least 2 elements.")))
def testIncorrectCardinality(self, num_elements, asserted_cardinality,
expected_error):
dataset = dataset_ops.Dataset.range(num_elements)
dataset = dataset.apply(
cardinality.assert_cardinality(asserted_cardinality))
get_next = self.getNext(dataset)
with self.assertRaisesRegexp(errors.FailedPreconditionError,
expected_error):
while True:
self.evaluate(get_next())
if __name__ == "__main__":
test.main()

View File

@ -168,9 +168,8 @@ class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testCardinality(self, dataset_fn, expected_result):
with self.cached_session() as sess:
self.assertEqual(
sess.run(cardinality.cardinality(dataset_fn())), expected_result)
self.assertEqual(
self.evaluate(cardinality.cardinality(dataset_fn())), expected_result)
if __name__ == "__main__":

View File

@ -31,6 +31,24 @@ py_library(
],
)
tf_py_test(
name = "assert_cardinality_dataset_serialization_test",
size = "small",
srcs = ["assert_cardinality_dataset_serialization_test.py"],
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "auto_shard_dataset_serialization_test",
size = "medium",

View File

@ -0,0 +1,45 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the AssertCardinalityDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
class AssertCardinalityDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase,
parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testCardinality(self):
def build_dataset(num_elements):
return dataset_ops.Dataset.range(num_elements).apply(
cardinality.assert_cardinality(num_elements))
self.run_core_tests(lambda: build_dataset(200), 200)
if __name__ == "__main__":
test.main()

View File

@ -320,7 +320,6 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
drop_remainder, use_legacy_function=False):
"""See `Dataset.map()` for details."""
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(

View File

@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export
@ -39,6 +42,18 @@ def cardinality(dataset):
analysis fails to determine the number of elements in `dataset` (e.g. when the
dataset source is a file).
>>> dataset = tf.data.Dataset.range(42)
>>> print(tf.data.experimental.cardinality(dataset).numpy())
42
>>> dataset = dataset.repeat()
>>> cardinality = tf.data.experimental.cardinality(dataset)
>>> print((cardinality == tf.data.experimental.INFINITE_CARDINALITY).numpy())
True
>>> dataset = dataset.filter(lambda x: True)
>>> cardinality = tf.data.experimental.cardinality(dataset)
>>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
True
Args:
dataset: A `tf.data.Dataset` for which to determine cardinality.
@ -49,3 +64,52 @@ def cardinality(dataset):
"""
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
@tf_export("data.experimental.assert_cardinality")
def assert_cardinality(expected_cardinality):
"""Asserts the cardinality of the input dataset.
NOTE: The following assumes that "examples.tfrecord" contains 42 records.
>>> dataset = tf.data.TFRecordDataset("examples.tfrecord")
>>> cardinality = tf.data.experimental.cardinality(dataset)
>>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
True
>>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(42))
>>> print(tf.data.experimental.cardinality(dataset).numpy())
42
Args:
expected_cardinality: The expected cardinality of the input dataset.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
Raises:
FailedPreconditionError: The assertion is checked at runtime (when iterating
the dataset) and an error is raised if the actual and expected cardinality
differ.
"""
def _apply_fn(dataset):
return _AssertCardinalityDataset(dataset, expected_cardinality)
return _apply_fn
class _AssertCardinalityDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that assert the cardinality of its input."""
def __init__(self, input_dataset, expected_cardinality):
self._input_dataset = input_dataset
self._expected_cardinality = ops.convert_to_tensor(
expected_cardinality, dtype=dtypes.int64, name="expected_cardinality")
# pylint: enable=protected-access
variant_tensor = ged_ops.assert_cardinality_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._expected_cardinality,
**self._flat_structure)
super(_AssertCardinalityDataset, self).__init__(input_dataset,
variant_tensor)

View File

@ -100,6 +100,10 @@ tf_module {
name: "TensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "assert_cardinality"
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "

View File

@ -204,6 +204,10 @@ tf_module {
name: "Assert"
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
}
member_method {
name: "AssertCardinalityDataset"
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "AssertNextDataset"
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -72,6 +72,10 @@ tf_module {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "assert_cardinality"
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "

View File

@ -204,6 +204,10 @@ tf_module {
name: "Assert"
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
}
member_method {
name: "AssertCardinalityDataset"
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "AssertNextDataset"
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "