[tf.data] Introducing an API for statically asserting the cardinality of a dataset. This API improves the user experience when using tf.data with Keras in situations where the cardinality of a file-based dataset is known and can thus be asserted. Having a dataset with a statically known cardinality allows Keras to display an accurate progress bar on the first epoch.
PiperOrigin-RevId: 295268926 Change-Id: If15704bee651fbefa385b1bc2d2024633d008bcf
This commit is contained in:
parent
0148382b20
commit
d25235b196
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "AssertCardinalityDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -14,6 +14,17 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
# Kernel library for the AssertCardinalityDataset op, built from
# assert_cardinality_dataset_op.cc and its header.
tf_kernel_library(
|
||||
name = "assert_cardinality_dataset_op",
|
||||
srcs = ["assert_cardinality_dataset_op.cc"],
|
||||
hdrs = ["assert_cardinality_dataset_op.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "assert_next_dataset_op",
|
||||
srcs = ["assert_next_dataset_op.cc"],
|
||||
@ -22,7 +33,6 @@ tf_kernel_library(
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
@ -557,6 +567,7 @@ tf_cc_test(
|
||||
tf_kernel_library(
|
||||
name = "dataset_kernels",
|
||||
deps = [
|
||||
":assert_cardinality_dataset_op",
|
||||
":assert_next_dataset_op",
|
||||
":auto_shard_dataset_op",
|
||||
":choose_fastest_branch_dataset_op",
|
||||
|
@ -0,0 +1,180 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace experimental {
|
||||
|
||||
// Out-of-class definitions for the `static constexpr` string constants that
// are declared (with their initializers) in the class body in the header.
// Before C++17, an odr-used static constexpr data member still requires a
// namespace-scope definition like the ones below.
/* static */ constexpr const char* const
|
||||
AssertCardinalityDatasetOp::kInputDataset;
|
||||
/* static */ constexpr const char* const
|
||||
AssertCardinalityDatasetOp::kDatasetType;
|
||||
/* static */ constexpr const char* const
|
||||
AssertCardinalityDatasetOp::kCardinality;
|
||||
/* static */ constexpr const char* const
|
||||
AssertCardinalityDatasetOp::kOutputTypes;
|
||||
/* static */ constexpr const char* const
|
||||
AssertCardinalityDatasetOp::kOutputShapes;
|
||||
|
||||
// Dataset that forwards the elements of `input` unchanged while reporting the
// asserted `cardinality` from `Cardinality()`. The assertion itself is checked
// by the iterator while elements are being consumed.
class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 cardinality,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        cardinality_(cardinality),
        output_types_(output_types),
        output_shapes_(output_shapes) {
    // Keep the input alive for the lifetime of this dataset; released in the
    // destructor.
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    const Iterator::Params iterator_params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)};
    return absl::make_unique<Iterator>(iterator_params);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  // Returns the asserted cardinality rather than the cardinality of the input.
  int64 Cardinality() const override { return cardinality_; }

  // This dataset adds no external state of its own; defer to the input.
  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  // Serializes this dataset as a node whose inputs are the input dataset and
  // the scalar asserted cardinality.
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));

    Node* cardinality_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(cardinality_, &cardinality_node));

    return b->AddDataset(this, {input_node, cardinality_node}, output);
  }

 private:
  // Iterator that counts the elements produced by the input iterator and
  // fails with FailedPrecondition as soon as the count is known to disagree
  // with the asserted cardinality.
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      TF_RETURN_IF_ERROR(
          input_impl_->GetNext(ctx, out_tensors, end_of_sequence));

      const int64 expected = dataset()->cardinality_;
      if (*end_of_sequence) {
        // The input is exhausted: the running count is now the true
        // cardinality and must match the assertion exactly.
        if (num_elements_ != expected) {
          return errors::FailedPrecondition(
              "Input dataset was expected to contain ", ElementString(expected),
              " but contained only ", ElementString(num_elements_), ".");
        }
      } else {
        ++num_elements_;
      }

      // More elements than asserted have been seen; fail without waiting for
      // the end of the input.
      if (num_elements_ > expected) {
        return errors::FailedPrecondition(
            "Input dataset was expected to contain ", ElementString(expected),
            " but contained at least ", ElementString(num_elements_), ".");
      }
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    // Checkpoints the running element count followed by the input iterator.
    Status SaveInternal(IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name("num_elements"), num_elements_));
      return SaveInput(writer, input_impl_);
    }

    // Restores the state written by `SaveInternal`, in the same order.
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name("num_elements"), &num_elements_));
      return RestoreInput(ctx, reader, input_impl_);
    }

   private:
    // Formats `n` as "<n> element" / "<n> elements" for error messages.
    static string ElementString(int64 n) {
      const char* const plural_suffix = (n == 1) ? "" : "s";
      return strings::StrCat(n, " element", plural_suffix);
    }

    std::unique_ptr<IteratorBase> input_impl_;
    // Number of elements produced by the input iterator so far.
    int64 num_elements_ = 0;
  };

  const DatasetBase* input_;
  const int64 cardinality_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};
|
||||
|
||||
// Reads the `output_types` and `output_shapes` attrs once, at kernel
// construction time, so that `MakeDataset` can hand them to each dataset.
AssertCardinalityDatasetOp::AssertCardinalityDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
|
||||
|
||||
// Parses the scalar `cardinality` input and wraps `input` in a `Dataset` that
// reports that cardinality. On a parse failure the error is recorded on `ctx`
// and `*output` is left untouched.
void AssertCardinalityDatasetOp::MakeDataset(OpKernelContext* ctx,
                                             DatasetBase* input,
                                             DatasetBase** output) {
  int64 asserted_cardinality;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCardinality,
                                                 &asserted_cardinality));
  *output = new Dataset(ctx, input, asserted_cardinality, output_types_,
                        output_shapes_);
}
|
||||
|
||||
namespace {
|
||||
// Registers AssertCardinalityDatasetOp as the CPU kernel for the
// "AssertCardinalityDataset" op.
REGISTER_KERNEL_BUILDER(Name("AssertCardinalityDataset").Device(DEVICE_CPU),
|
||||
AssertCardinalityDatasetOp);
|
||||
} // namespace
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace experimental {
|
||||
|
||||
class AssertCardinalityDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
static constexpr const char* const kDatasetType = "AssertCardinality";
|
||||
static constexpr const char* const kInputDataset = "input_dataset";
|
||||
static constexpr const char* const kCardinality = "cardinality";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
|
||||
explicit AssertCardinalityDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override;
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
@ -17,6 +17,19 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Op definition: takes a dataset variant and an int64 `cardinality`, and
// produces a dataset variant `handle`. Shape inference requires `cardinality`
// to be a scalar and reports a scalar output.
REGISTER_OP("AssertCardinalityDataset")
    .Input("input_dataset: variant")
    .Input("cardinality: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Input 1 is `cardinality`; it must have rank 0. The resulting handle
      // is only needed to satisfy the WithRank signature.
      shape_inference::ShapeHandle cardinality_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &cardinality_shape));
      return shape_inference::ScalarShape(c);
    });
|
||||
|
||||
REGISTER_OP("AssertNextDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("transformations: string")
|
||||
|
@ -45,6 +45,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
||||
@@TensorStructure
|
||||
@@ThreadingOptions
|
||||
|
||||
@@assert_cardinality
|
||||
@@bucket_by_sequence_length
|
||||
@@bytes_produced_stats
|
||||
@@cardinality
|
||||
@ -93,6 +94,7 @@ from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_bat
|
||||
from tensorflow.python.data.experimental.ops.batching import map_and_batch
|
||||
from tensorflow.python.data.experimental.ops.batching import map_and_batch_with_legacy_function
|
||||
from tensorflow.python.data.experimental.ops.batching import unbatch
|
||||
from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
|
||||
from tensorflow.python.data.experimental.ops.cardinality import cardinality
|
||||
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
|
||||
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY
|
||||
|
@ -8,6 +8,17 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
# Python test target for assert_cardinality_test.py.
tf_py_test(
|
||||
name = "assert_cardinality_test",
|
||||
srcs = ["assert_cardinality_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/python/data/experimental/ops:cardinality",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "assert_next_test",
|
||||
size = "small",
|
||||
|
@ -0,0 +1,79 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.assert_cardinality()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.experimental.assert_cardinality()`."""

  @combinations.generate(test_base.default_test_combinations())
  def testCorrectCardinality(self):
    """A correct assertion sets the cardinality and leaves elements intact."""
    # Before the assertion the reported cardinality is UNKNOWN.
    unasserted = dataset_ops.Dataset.range(10).filter(lambda x: True)
    self.assertEqual(
        self.evaluate(cardinality.cardinality(unasserted)),
        cardinality.UNKNOWN)
    self.assertDatasetProduces(unasserted, expected_output=range(10))

    # After the assertion the reported cardinality is the asserted value.
    asserted = unasserted.apply(cardinality.assert_cardinality(10))
    self.assertEqual(self.evaluate(cardinality.cardinality(asserted)), 10)
    self.assertDatasetProduces(asserted, expected_output=range(10))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          # Too few elements (plural / singular count in the message).
          combinations.combine(
              num_elements=10,
              asserted_cardinality=20,
              expected_error="Input dataset was expected to contain 20 "
              "elements but contained only 10 elements.") +
          combinations.combine(
              num_elements=1,
              asserted_cardinality=20,
              expected_error="Input dataset was expected to contain 20 "
              "elements but contained only 1 element.") +
          # Too many elements (plural / singular asserted cardinality).
          combinations.combine(
              num_elements=10,
              asserted_cardinality=5,
              expected_error="Input dataset was expected to contain 5 "
              "elements but contained at least 6 elements.") +
          combinations.combine(
              num_elements=10,
              asserted_cardinality=1,
              expected_error="Input dataset was expected to contain 1 "
              "element but contained at least 2 elements.")))
  def testIncorrectCardinality(self, num_elements, asserted_cardinality,
                               expected_error):
    """An incorrect assertion raises FailedPreconditionError on iteration."""
    dataset = dataset_ops.Dataset.range(num_elements).apply(
        cardinality.assert_cardinality(asserted_cardinality))
    get_next = self.getNext(dataset)
    # Keep pulling elements until the expected error is raised.
    with self.assertRaisesRegexp(errors.FailedPreconditionError,
                                 expected_error):
      while True:
        self.evaluate(get_next())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -168,9 +168,8 @@ class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
_test_combinations()))
|
||||
def testCardinality(self, dataset_fn, expected_result):
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(
|
||||
sess.run(cardinality.cardinality(dataset_fn())), expected_result)
|
||||
self.assertEqual(
|
||||
self.evaluate(cardinality.cardinality(dataset_fn())), expected_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -31,6 +31,24 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Serialization test for the AssertCardinalityDataset transformation.
# The test source imports `data.experimental.ops.cardinality` and
# `data.kernel_tests.test_base` directly, so both are listed as explicit deps
# instead of being relied upon transitively.
tf_py_test(
    name = "assert_cardinality_dataset_serialization_test",
    size = "small",
    srcs = ["assert_cardinality_dataset_serialization_test.py"],
    tags = [
        "no_oss",
        "no_pip",
        "no_windows",
    ],
    deps = [
        ":dataset_serialization_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/experimental/ops:cardinality",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "@absl_py//absl/testing:parameterized",
    ],
)
|
||||
|
||||
tf_py_test(
|
||||
name = "auto_shard_dataset_serialization_test",
|
||||
size = "medium",
|
||||
|
@ -0,0 +1,45 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the AssertCardinalityDataset serialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class AssertCardinalityDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase,
    parameterized.TestCase):
  """Serialization tests for the `assert_cardinality` transformation."""

  @combinations.generate(test_base.default_test_combinations())
  def testCardinality(self):
    num_elements = 200

    def build_dataset():
      # The asserted cardinality matches the real number of elements, so
      # iteration is expected to succeed.
      source = dataset_ops.Dataset.range(num_elements)
      return source.apply(cardinality.assert_cardinality(num_elements))

    self.run_core_tests(build_dataset, num_elements)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -320,7 +320,6 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
|
||||
drop_remainder, use_legacy_function=False):
|
||||
"""See `Dataset.map()` for details."""
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||
|
@ -17,6 +17,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -39,6 +42,18 @@ def cardinality(dataset):
|
||||
analysis fails to determine the number of elements in `dataset` (e.g. when the
|
||||
dataset source is a file).
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(42)
|
||||
>>> print(tf.data.experimental.cardinality(dataset).numpy())
|
||||
42
|
||||
>>> dataset = dataset.repeat()
|
||||
>>> cardinality = tf.data.experimental.cardinality(dataset)
|
||||
>>> print((cardinality == tf.data.experimental.INFINITE_CARDINALITY).numpy())
|
||||
True
|
||||
>>> dataset = dataset.filter(lambda x: True)
|
||||
>>> cardinality = tf.data.experimental.cardinality(dataset)
|
||||
>>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
|
||||
True
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` for which to determine cardinality.
|
||||
|
||||
@ -49,3 +64,52 @@ def cardinality(dataset):
|
||||
"""
|
||||
|
||||
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@tf_export("data.experimental.assert_cardinality")
def assert_cardinality(expected_cardinality):
  """Asserts the cardinality of the input dataset.

  The returned transformation makes the dataset report `expected_cardinality`
  as its cardinality. Calling this function does not check anything; the
  assertion is verified while the resulting dataset is iterated.

  NOTE: The following assumes that "examples.tfrecord" contains 42 records.

  >>> dataset = tf.data.TFRecordDataset("examples.tfrecord")
  >>> cardinality = tf.data.experimental.cardinality(dataset)
  >>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
  True
  >>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(42))
  >>> print(tf.data.experimental.cardinality(dataset).numpy())
  42

  Args:
    expected_cardinality: The expected cardinality of the input dataset.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    FailedPreconditionError: The assertion is checked at runtime (when iterating
      the dataset) and an error is raised if the actual and expected cardinality
      differ.
  """

  def _apply_fn(dataset):
    return _AssertCardinalityDataset(dataset, expected_cardinality)

  return _apply_fn
|
||||
|
||||
|
||||
class _AssertCardinalityDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts the cardinality of its input."""

  def __init__(self, input_dataset, expected_cardinality):
    """Creates the dataset.

    Args:
      input_dataset: The dataset whose cardinality is being asserted.
      expected_cardinality: The asserted cardinality; converted to an int64
        tensor.
    """
    self._input_dataset = input_dataset
    self._expected_cardinality = ops.convert_to_tensor(
        expected_cardinality, dtype=dtypes.int64, name="expected_cardinality")

    variant_tensor = ged_ops.assert_cardinality_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._expected_cardinality,
        **self._flat_structure)
    super(_AssertCardinalityDataset, self).__init__(input_dataset,
                                                    variant_tensor)
|
||||
|
@ -100,6 +100,10 @@ tf_module {
|
||||
name: "TensorStructure"
|
||||
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "assert_cardinality"
|
||||
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "bucket_by_sequence_length"
|
||||
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
||||
|
@ -204,6 +204,10 @@ tf_module {
|
||||
name: "Assert"
|
||||
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "AssertCardinalityDataset"
|
||||
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "AssertNextDataset"
|
||||
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -72,6 +72,10 @@ tf_module {
|
||||
name: "Counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "assert_cardinality"
|
||||
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "bucket_by_sequence_length"
|
||||
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
||||
|
@ -204,6 +204,10 @@ tf_module {
|
||||
name: "Assert"
|
||||
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "AssertCardinalityDataset"
|
||||
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "AssertNextDataset"
|
||||
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user