[tf.data] Introducing API for statically asserting the cardinality of a dataset. This API improves user experience when using tf.data with Keras in situations where the cardinality of a file-based dataset is known and can thus be asserted. Having a dataset with a statically known cardinality allows Keras to display accurate progress bar on the first epoch.
PiperOrigin-RevId: 295268926 Change-Id: If15704bee651fbefa385b1bc2d2024633d008bcf
This commit is contained in:
parent
0148382b20
commit
d25235b196
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "AssertCardinalityDataset"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -14,6 +14,17 @@ package(
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "assert_cardinality_dataset_op",
|
||||||
|
srcs = ["assert_cardinality_dataset_op.cc"],
|
||||||
|
hdrs = ["assert_cardinality_dataset_op.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/kernels/data:name_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "assert_next_dataset_op",
|
name = "assert_next_dataset_op",
|
||||||
srcs = ["assert_next_dataset_op.cc"],
|
srcs = ["assert_next_dataset_op.cc"],
|
||||||
@ -22,7 +33,6 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/kernels/data:name_utils",
|
"//tensorflow/core/kernels/data:name_utils",
|
||||||
"//third_party/eigen3",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -557,6 +567,7 @@ tf_cc_test(
|
|||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "dataset_kernels",
|
name = "dataset_kernels",
|
||||||
deps = [
|
deps = [
|
||||||
|
":assert_cardinality_dataset_op",
|
||||||
":assert_next_dataset_op",
|
":assert_next_dataset_op",
|
||||||
":auto_shard_dataset_op",
|
":auto_shard_dataset_op",
|
||||||
":choose_fastest_branch_dataset_op",
|
":choose_fastest_branch_dataset_op",
|
||||||
|
@ -0,0 +1,180 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace experimental {
|
||||||
|
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
AssertCardinalityDatasetOp::kInputDataset;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
AssertCardinalityDatasetOp::kDatasetType;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
AssertCardinalityDatasetOp::kCardinality;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
AssertCardinalityDatasetOp::kOutputTypes;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
AssertCardinalityDatasetOp::kOutputShapes;
|
||||||
|
|
||||||
|
class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
|
||||||
|
public:
|
||||||
|
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 cardinality,
|
||||||
|
const DataTypeVector& output_types,
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes)
|
||||||
|
: DatasetBase(DatasetContext(ctx)),
|
||||||
|
input_(input),
|
||||||
|
cardinality_(cardinality),
|
||||||
|
output_types_(output_types),
|
||||||
|
output_shapes_(output_shapes) {
|
||||||
|
input_->Ref();
|
||||||
|
}
|
||||||
|
|
||||||
|
~Dataset() override { input_->Unref(); }
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||||
|
const string& prefix) const override {
|
||||||
|
return absl::make_unique<Iterator>(Iterator::Params{
|
||||||
|
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
|
||||||
|
}
|
||||||
|
|
||||||
|
const DataTypeVector& output_dtypes() const override { return output_types_; }
|
||||||
|
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||||
|
return output_shapes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString() const override {
|
||||||
|
return name_utils::DatasetDebugString(kDatasetType);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 Cardinality() const override { return cardinality_; }
|
||||||
|
|
||||||
|
Status CheckExternalState() const override {
|
||||||
|
return input_->CheckExternalState();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||||
|
DatasetGraphDefBuilder* b,
|
||||||
|
Node** output) const override {
|
||||||
|
Node* input_graph_node = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||||
|
Node* cardinality_node = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(b->AddScalar(cardinality_, &cardinality_node));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
b->AddDataset(this, {input_graph_node, cardinality_node}, output));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Iterator : public DatasetIterator<Dataset> {
|
||||||
|
public:
|
||||||
|
explicit Iterator(const Params& params)
|
||||||
|
: DatasetIterator<Dataset>(params), num_elements_(0) {}
|
||||||
|
|
||||||
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
|
return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
|
std::vector<Tensor>* out_tensors,
|
||||||
|
bool* end_of_sequence) override {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||||
|
if (!*end_of_sequence) {
|
||||||
|
num_elements_++;
|
||||||
|
}
|
||||||
|
if (*end_of_sequence && num_elements_ != dataset()->cardinality_) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Input dataset was expected to contain ",
|
||||||
|
ElementString(dataset()->cardinality_), " but contained only ",
|
||||||
|
ElementString(num_elements_), ".");
|
||||||
|
}
|
||||||
|
if (num_elements_ > dataset()->cardinality_) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Input dataset was expected to contain ",
|
||||||
|
ElementString(dataset()->cardinality_), " but contained at least ",
|
||||||
|
ElementString(num_elements_), ".");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<model::Node> CreateNode(
|
||||||
|
IteratorContext* ctx, model::Node::Args args) const override {
|
||||||
|
return model::MakeKnownRatioNode(std::move(args),
|
||||||
|
/*ratio=*/1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(full_name("num_elements"), num_elements_));
|
||||||
|
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
|
IteratorStateReader* reader) override {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
reader->ReadScalar(full_name("num_elements"), &num_elements_));
|
||||||
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static string ElementString(int64 n) {
|
||||||
|
return strings::StrCat(n, " element", n != 1 ? "s" : "");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorBase> input_impl_;
|
||||||
|
int64 num_elements_;
|
||||||
|
};
|
||||||
|
|
||||||
|
const DatasetBase* input_;
|
||||||
|
const int64 cardinality_;
|
||||||
|
const DataTypeVector output_types_;
|
||||||
|
const std::vector<PartialTensorShape> output_shapes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
AssertCardinalityDatasetOp::AssertCardinalityDatasetOp(
|
||||||
|
OpKernelConstruction* ctx)
|
||||||
|
: UnaryDatasetOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void AssertCardinalityDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||||
|
DatasetBase* input,
|
||||||
|
DatasetBase** output) {
|
||||||
|
int64 cardinality;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ParseScalarArgument<int64>(ctx, kCardinality, &cardinality));
|
||||||
|
*output = new Dataset(ctx, input, cardinality, output_types_, output_shapes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("AssertCardinalityDataset").Device(DEVICE_CPU),
|
||||||
|
AssertCardinalityDatasetOp);
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace experimental
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,48 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace experimental {
|
||||||
|
|
||||||
|
class AssertCardinalityDatasetOp : public UnaryDatasetOpKernel {
|
||||||
|
public:
|
||||||
|
static constexpr const char* const kDatasetType = "AssertCardinality";
|
||||||
|
static constexpr const char* const kInputDataset = "input_dataset";
|
||||||
|
static constexpr const char* const kCardinality = "cardinality";
|
||||||
|
static constexpr const char* const kOutputTypes = "output_types";
|
||||||
|
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||||
|
|
||||||
|
explicit AssertCardinalityDatasetOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||||
|
DatasetBase** output) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Dataset;
|
||||||
|
DataTypeVector output_types_;
|
||||||
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace experimental
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
|
@ -17,6 +17,19 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_OP("AssertCardinalityDataset")
|
||||||
|
.Input("input_dataset: variant")
|
||||||
|
.Input("cardinality: int64")
|
||||||
|
.Output("handle: variant")
|
||||||
|
.Attr("output_types: list(type) >= 1")
|
||||||
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
shape_inference::ShapeHandle unused;
|
||||||
|
// cardinality should be a scalar.
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||||
|
return shape_inference::ScalarShape(c);
|
||||||
|
});
|
||||||
|
|
||||||
REGISTER_OP("AssertNextDataset")
|
REGISTER_OP("AssertNextDataset")
|
||||||
.Input("input_dataset: variant")
|
.Input("input_dataset: variant")
|
||||||
.Input("transformations: string")
|
.Input("transformations: string")
|
||||||
|
@ -45,6 +45,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
|||||||
@@TensorStructure
|
@@TensorStructure
|
||||||
@@ThreadingOptions
|
@@ThreadingOptions
|
||||||
|
|
||||||
|
@@assert_cardinality
|
||||||
@@bucket_by_sequence_length
|
@@bucket_by_sequence_length
|
||||||
@@bytes_produced_stats
|
@@bytes_produced_stats
|
||||||
@@cardinality
|
@@cardinality
|
||||||
@ -93,6 +94,7 @@ from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_bat
|
|||||||
from tensorflow.python.data.experimental.ops.batching import map_and_batch
|
from tensorflow.python.data.experimental.ops.batching import map_and_batch
|
||||||
from tensorflow.python.data.experimental.ops.batching import map_and_batch_with_legacy_function
|
from tensorflow.python.data.experimental.ops.batching import map_and_batch_with_legacy_function
|
||||||
from tensorflow.python.data.experimental.ops.batching import unbatch
|
from tensorflow.python.data.experimental.ops.batching import unbatch
|
||||||
|
from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
|
||||||
from tensorflow.python.data.experimental.ops.cardinality import cardinality
|
from tensorflow.python.data.experimental.ops.cardinality import cardinality
|
||||||
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
|
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
|
||||||
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY
|
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY
|
||||||
|
@ -8,6 +8,17 @@ package(
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "assert_cardinality_test",
|
||||||
|
srcs = ["assert_cardinality_test.py"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/data/experimental/ops:cardinality",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "assert_next_test",
|
name = "assert_next_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -0,0 +1,79 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for `tf.data.experimental.assert_cardinality()`."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.ops import cardinality
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
"""Tests for `tf.data.experimental.assert_cardinality()`."""
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testCorrectCardinality(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(10).filter(lambda x: True)
|
||||||
|
self.assertEqual(
|
||||||
|
self.evaluate(cardinality.cardinality(dataset)), cardinality.UNKNOWN)
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=range(10))
|
||||||
|
dataset = dataset.apply(cardinality.assert_cardinality(10))
|
||||||
|
self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=range(10))
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
test_base.default_test_combinations(),
|
||||||
|
combinations.combine(
|
||||||
|
num_elements=10,
|
||||||
|
asserted_cardinality=20,
|
||||||
|
expected_error="Input dataset was expected to contain 20 "
|
||||||
|
"elements but contained only 10 elements.") +
|
||||||
|
combinations.combine(
|
||||||
|
num_elements=1,
|
||||||
|
asserted_cardinality=20,
|
||||||
|
expected_error="Input dataset was expected to contain 20 "
|
||||||
|
"elements but contained only 1 element.") +
|
||||||
|
combinations.combine(
|
||||||
|
num_elements=10,
|
||||||
|
asserted_cardinality=5,
|
||||||
|
expected_error="Input dataset was expected to contain 5 "
|
||||||
|
"elements but contained at least 6 elements.") +
|
||||||
|
combinations.combine(
|
||||||
|
num_elements=10,
|
||||||
|
asserted_cardinality=1,
|
||||||
|
expected_error="Input dataset was expected to contain 1 "
|
||||||
|
"element but contained at least 2 elements.")))
|
||||||
|
def testIncorrectCardinality(self, num_elements, asserted_cardinality,
|
||||||
|
expected_error):
|
||||||
|
dataset = dataset_ops.Dataset.range(num_elements)
|
||||||
|
dataset = dataset.apply(
|
||||||
|
cardinality.assert_cardinality(asserted_cardinality))
|
||||||
|
get_next = self.getNext(dataset)
|
||||||
|
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||||
|
expected_error):
|
||||||
|
while True:
|
||||||
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -168,9 +168,8 @@ class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
combinations.times(test_base.default_test_combinations(),
|
combinations.times(test_base.default_test_combinations(),
|
||||||
_test_combinations()))
|
_test_combinations()))
|
||||||
def testCardinality(self, dataset_fn, expected_result):
|
def testCardinality(self, dataset_fn, expected_result):
|
||||||
with self.cached_session() as sess:
|
self.assertEqual(
|
||||||
self.assertEqual(
|
self.evaluate(cardinality.cardinality(dataset_fn())), expected_result)
|
||||||
sess.run(cardinality.cardinality(dataset_fn())), expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -31,6 +31,24 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "assert_cardinality_dataset_serialization_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["assert_cardinality_dataset_serialization_test.py"],
|
||||||
|
tags = [
|
||||||
|
"no_oss",
|
||||||
|
"no_pip",
|
||||||
|
"no_windows",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset_serialization_test_base",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "auto_shard_dataset_serialization_test",
|
name = "auto_shard_dataset_serialization_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for the AssertCardinalityDataset serialization."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||||
|
from tensorflow.python.data.experimental.ops import cardinality
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class AssertCardinalityDatasetSerializationTest(
|
||||||
|
dataset_serialization_test_base.DatasetSerializationTestBase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testCardinality(self):
|
||||||
|
|
||||||
|
def build_dataset(num_elements):
|
||||||
|
return dataset_ops.Dataset.range(num_elements).apply(
|
||||||
|
cardinality.assert_cardinality(num_elements))
|
||||||
|
|
||||||
|
self.run_core_tests(lambda: build_dataset(200), 200)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -320,7 +320,6 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
|||||||
|
|
||||||
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
|
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
|
||||||
drop_remainder, use_legacy_function=False):
|
drop_remainder, use_legacy_function=False):
|
||||||
"""See `Dataset.map()` for details."""
|
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
|
|
||||||
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
self._map_func = dataset_ops.StructuredFunctionWrapper(
|
||||||
|
@ -17,6 +17,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -39,6 +42,18 @@ def cardinality(dataset):
|
|||||||
analysis fails to determine the number of elements in `dataset` (e.g. when the
|
analysis fails to determine the number of elements in `dataset` (e.g. when the
|
||||||
dataset source is a file).
|
dataset source is a file).
|
||||||
|
|
||||||
|
>>> dataset = tf.data.Dataset.range(42)
|
||||||
|
>>> print(tf.data.experimental.cardinality(dataset).numpy())
|
||||||
|
42
|
||||||
|
>>> dataset = dataset.repeat()
|
||||||
|
>>> cardinality = tf.data.experimental.cardinality(dataset)
|
||||||
|
>>> print((cardinality == tf.data.experimental.INFINITE_CARDINALITY).numpy())
|
||||||
|
True
|
||||||
|
>>> dataset = dataset.filter(lambda x: True)
|
||||||
|
>>> cardinality = tf.data.experimental.cardinality(dataset)
|
||||||
|
>>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
|
||||||
|
True
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: A `tf.data.Dataset` for which to determine cardinality.
|
dataset: A `tf.data.Dataset` for which to determine cardinality.
|
||||||
|
|
||||||
@ -49,3 +64,52 @@ def cardinality(dataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
return ged_ops.dataset_cardinality(dataset._variant_tensor) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("data.experimental.assert_cardinality")
|
||||||
|
def assert_cardinality(expected_cardinality):
|
||||||
|
"""Asserts the cardinality of the input dataset.
|
||||||
|
|
||||||
|
NOTE: The following assumes that "examples.tfrecord" contains 42 records.
|
||||||
|
|
||||||
|
>>> dataset = tf.data.TFRecordDataset("examples.tfrecord")
|
||||||
|
>>> cardinality = tf.data.experimental.cardinality(dataset)
|
||||||
|
>>> print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
|
||||||
|
True
|
||||||
|
>>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(42))
|
||||||
|
>>> print(tf.data.experimental.cardinality(dataset).numpy())
|
||||||
|
42
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_cardinality: The expected cardinality of the input dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Dataset` transformation function, which can be passed to
|
||||||
|
`tf.data.Dataset.apply`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FailedPreconditionError: The assertion is checked at runtime (when iterating
|
||||||
|
the dataset) and an error is raised if the actual and expected cardinality
|
||||||
|
differ.
|
||||||
|
"""
|
||||||
|
def _apply_fn(dataset):
|
||||||
|
return _AssertCardinalityDataset(dataset, expected_cardinality)
|
||||||
|
|
||||||
|
return _apply_fn
|
||||||
|
|
||||||
|
|
||||||
|
class _AssertCardinalityDataset(dataset_ops.UnaryUnchangedStructureDataset):
|
||||||
|
"""A `Dataset` that assert the cardinality of its input."""
|
||||||
|
|
||||||
|
def __init__(self, input_dataset, expected_cardinality):
|
||||||
|
self._input_dataset = input_dataset
|
||||||
|
self._expected_cardinality = ops.convert_to_tensor(
|
||||||
|
expected_cardinality, dtype=dtypes.int64, name="expected_cardinality")
|
||||||
|
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
variant_tensor = ged_ops.assert_cardinality_dataset(
|
||||||
|
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
|
self._expected_cardinality,
|
||||||
|
**self._flat_structure)
|
||||||
|
super(_AssertCardinalityDataset, self).__init__(input_dataset,
|
||||||
|
variant_tensor)
|
||||||
|
@ -100,6 +100,10 @@ tf_module {
|
|||||||
name: "TensorStructure"
|
name: "TensorStructure"
|
||||||
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "assert_cardinality"
|
||||||
|
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "bucket_by_sequence_length"
|
name: "bucket_by_sequence_length"
|
||||||
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
||||||
|
@ -204,6 +204,10 @@ tf_module {
|
|||||||
name: "Assert"
|
name: "Assert"
|
||||||
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "AssertCardinalityDataset"
|
||||||
|
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "AssertNextDataset"
|
name: "AssertNextDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -72,6 +72,10 @@ tf_module {
|
|||||||
name: "Counter"
|
name: "Counter"
|
||||||
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
|
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "assert_cardinality"
|
||||||
|
argspec: "args=[\'expected_cardinality\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "bucket_by_sequence_length"
|
name: "bucket_by_sequence_length"
|
||||||
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "
|
||||||
|
@ -204,6 +204,10 @@ tf_module {
|
|||||||
name: "Assert"
|
name: "Assert"
|
||||||
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "AssertCardinalityDataset"
|
||||||
|
argspec: "args=[\'input_dataset\', \'cardinality\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "AssertNextDataset"
|
name: "AssertNextDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input_dataset\', \'transformations\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user