From e9d47fbff0d644a75c6f3dcdcb852685ef515b64 Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Wed, 18 Apr 2018 16:01:55 -0700 Subject: [PATCH] Adds dataset transformation function `set_stats_aggregator(..)`, which sets the given `stats_aggregator` for aggregating the input dataset stats. PiperOrigin-RevId: 193432590 --- .../kernel_tests/stats_dataset_ops_test.py | 71 ++++----- .../contrib/data/python/ops/stats_ops.py | 61 +++++--- tensorflow/core/BUILD | 1 + .../api_def_IteratorSetStatsAggregator.pbtxt | 4 - .../api_def_SetStatsAggregatorDataset.pbtxt | 3 + .../api_def_IteratorSetStatsAggregator.pbtxt | 4 - .../api_def_SetStatsAggregatorDataset.pbtxt | 4 + .../data => framework}/stats_aggregator.h | 6 +- tensorflow/core/kernels/data/BUILD | 32 ++--- tensorflow/core/kernels/data/iterator_ops.cc | 32 +---- .../data/stats_aggregator_dataset_op.cc | 135 ++++++++++++++++++ .../core/kernels/data/stats_aggregator_ops.cc | 2 +- .../core/kernels/data/stats_dataset_ops.cc | 2 +- .../core/ops/compat/ops_history.v1.pbtxt | 12 -- tensorflow/core/ops/dataset_ops.cc | 13 +- tensorflow/core/ops/ops.pbtxt | 12 -- 16 files changed, 242 insertions(+), 152 deletions(-) delete mode 100644 tensorflow/core/api_def/base_api/api_def_IteratorSetStatsAggregator.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_SetStatsAggregatorDataset.pbtxt delete mode 100644 tensorflow/core/api_def/python_api/api_def_IteratorSetStatsAggregator.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_SetStatsAggregatorDataset.pbtxt rename tensorflow/core/{kernels/data => framework}/stats_aggregator.h (94%) create mode 100644 tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 07bdf920446..7acbc676ceb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -50,17 +50,17 @@ class StatsDatasetTest(test.TestCase): self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) def testBytesProduced(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) + stats_ops.bytes_produced_stats("bytes_produced")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) expected_sum = 0.0 for i in range(100): self.assertAllEqual( @@ -76,16 +76,16 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) def testLatencyStats(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) - iterator = dataset.make_initializable_iterator() stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -95,16 +95,15 @@ class StatsDatasetTest(test.TestCase): 
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) def testReinitialize(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) - iterator = dataset.make_initializable_iterator() stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run(stats_aggregator_subscriber) for j in range(5): sess.run(iterator.initializer) for i in range(100): @@ -130,17 +129,17 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) def testMultipleTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - stats_ops.latency_stats("record_latency_2")) + stats_ops.latency_stats("record_latency_2")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -154,17 +153,17 @@ class StatsDatasetTest(test.TestCase): sess.run(summary_t), "record_latency_2", 100.0) def testRepeatedTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - 
stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -174,19 +173,17 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) def testMultipleIteratorsSameAggregator(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator_0 = dataset.make_initializable_iterator() iterator_1 = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), - stats_aggregator.subscribe(iterator_1)] next_element = iterator_0.get_next() + iterator_1.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator_0.initializer, iterator_1.initializer, - stats_aggregator_subscribers]) + sess.run([iterator_0.initializer, iterator_1.initializer]) for i in range(100): self.assertEqual(i * 2, sess.run(next_element)) self._assertSummaryHasCount( @@ -195,20 +192,6 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) - def testMultipleStatsAggregatorsSameIteratorFail(self): - dataset = dataset_ops.Dataset.range(100).apply( - 
stats_ops.latency_stats("record_latency")) - iterator = dataset.make_initializable_iterator() - stats_aggregator_0 = stats_ops.StatsAggregator() - stats_aggregator_1 = stats_ops.StatsAggregator() - - with self.test_session() as sess: - sess.run(stats_aggregator_0.subscribe(iterator)) - # TODO(mrry): Consider making this allowable (and also allowing - # aggregators to unsubscribe). - with self.assertRaises(errors.FailedPreconditionError): - sess.run(stats_aggregator_1.subscribe(iterator)) - class StatsDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -253,5 +236,9 @@ class StatsDatasetSerializationTest( None, num_outputs) +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index b5cf0fcfe91..d3917203968 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -85,25 +84,53 @@ class StatsAggregator(object): """ return gen_dataset_ops.stats_aggregator_summary(self._resource) - def subscribe(self, iterator): - """Returns a @{tf.Operation} to associate this aggregator with `iterator`. - Note: Each @{tf.data.Iterator} can be associated with at most one - `StatsAggregator`. After running the operation that this function - returns, all statistics recorded in the iteration of `iterator` - will be stored in `stats_aggregator`. 
+class _SetStatsAggregatorDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" - Args: - iterator: A @{tf.data.Iterator} object. + def __init__(self, input_dataset, stats_aggregator): + super(_SetStatsAggregatorDataset, self).__init__() + self._input_dataset = input_dataset + self._stats_aggregator = stats_aggregator - Returns: - A @{tf.Operation} that, when run, associates this aggregator with - `iterator`. - """ - if not isinstance(iterator, iterator_ops.Iterator): - raise TypeError("`iterator` must be a `tf.data.Iterator` object.") - return gen_dataset_ops.iterator_set_stats_aggregator( - iterator._iterator_resource, self._resource) # pylint: disable=protected-access + def _as_variant_tensor(self): + return gen_dataset_ops.set_stats_aggregator_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._stats_aggregator._resource, # pylint: disable=protected-access + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +def set_stats_aggregator(stats_aggregator): + """Set the given stats_aggregator for aggregating the input dataset stats. + + Args: + stats_aggregator: A `StatsAggregator` object. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + return _SetStatsAggregatorDataset(dataset, stats_aggregator) + + return _apply_fn def bytes_produced_stats(tag): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 21f929894cd..54e7ab31d75 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -547,6 +547,7 @@ tf_cuda_library( "framework/selective_registration.h", "framework/session_state.h", "framework/shape_inference.h", + "framework/stats_aggregator.h", "framework/tensor.h", "framework/tensor_shape.h", "framework/tensor_slice.h", diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorSetStatsAggregator.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorSetStatsAggregator.pbtxt deleted file mode 100644 index c6f2212cd4f..00000000000 --- a/tensorflow/core/api_def/base_api/api_def_IteratorSetStatsAggregator.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "IteratorSetStatsAggregator" - summary: "Associates the given iterator with the given statistics aggregator." 
-} diff --git a/tensorflow/core/api_def/base_api/api_def_SetStatsAggregatorDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_SetStatsAggregatorDataset.pbtxt new file mode 100644 index 00000000000..77123e143b2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SetStatsAggregatorDataset.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "SetStatsAggregatorDataset" +} diff --git a/tensorflow/core/api_def/python_api/api_def_IteratorSetStatsAggregator.pbtxt b/tensorflow/core/api_def/python_api/api_def_IteratorSetStatsAggregator.pbtxt deleted file mode 100644 index db51ae3873c..00000000000 --- a/tensorflow/core/api_def/python_api/api_def_IteratorSetStatsAggregator.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "IteratorSetStatsAggregator" - visibility: HIDDEN -} diff --git a/tensorflow/core/api_def/python_api/api_def_SetStatsAggregatorDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_SetStatsAggregatorDataset.pbtxt new file mode 100644 index 00000000000..3a8c1036ca3 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SetStatsAggregatorDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SetStatsAggregatorDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/data/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h similarity index 94% rename from tensorflow/core/kernels/data/stats_aggregator.h rename to tensorflow/core/framework/stats_aggregator.h index 076a56b0bf1..a449f324e60 100644 --- a/tensorflow/core/kernels/data/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ -#define TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ #include <memory> #include <string> @@ -81,4 +81,4 @@ class StatsAggregatorResource : public ResourceBase { } // namespace tensorflow -#endif // TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e856ede44bc..221724e25d8 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -13,20 +13,10 @@ load( "tf_cc_test", ) -cc_library( - name = "stats_aggregator", - hdrs = ["stats_aggregator.h"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - tf_kernel_library( name = "stats_aggregator_ops", srcs = ["stats_aggregator_ops.cc"], deps = [ - ":stats_aggregator", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -38,14 +28,7 @@ cc_library( name = "dataset", srcs = [], hdrs = ["dataset.h"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], + deps = ["//tensorflow/core:framework"], ) cc_library( @@ -360,7 +343,6 @@ tf_kernel_library( srcs = ["stats_dataset_ops.cc"], deps = [ ":dataset", - ":stats_aggregator", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -368,6 +350,16 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "stats_aggregator_dataset_op", + srcs = ["stats_aggregator_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = 
"random_dataset_op", srcs = ["random_dataset_op.cc"], @@ -510,7 +502,6 @@ tf_kernel_library( srcs = ["iterator_ops.cc"], deps = [ ":dataset", - ":stats_aggregator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -564,6 +555,7 @@ tf_kernel_library( ":slide_dataset_op", ":sparse_tensor_slice_dataset_op", ":sql_dataset_ops", + ":stats_aggregator_dataset_op", ":stats_aggregator_ops", ":stats_dataset_ops", ":take_dataset_op", diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 780f927a4f1..4e4997d7b3f 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/stats_aggregator.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -203,10 +203,6 @@ class IteratorResource : public ResourceBase { return Status::OK(); } - void set_stats_aggregator(std::shared_ptr stats_aggregator) { - mutex_lock l(mu_); - stats_aggregator_ = std::move(stats_aggregator); - } std::shared_ptr stats_aggregator() { tf_shared_lock l(mu_); @@ -1075,30 +1071,6 @@ class DeserializeIteratorOp : public OpKernel { } }; -class IteratorSetStatsAggregatorOp : public OpKernel { - public: - explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - 
IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - core::ScopedUnref unref_iterator(iterator_resource); - - StatsAggregatorResource* stats_aggregator_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &stats_aggregator_resource)); - core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); - // TODO(mrry): Consider allowing multiple StatsAggregator ops to - // subscribe to updates, and/or unsubscribing. - OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(), - errors::FailedPrecondition( - "Iterator already associated with a StatsAggregator")); - iterator_resource->set_stats_aggregator( - stats_aggregator_resource->stats_aggregator()); - } -}; REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), @@ -1119,8 +1091,6 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), SerializeIteratorOp); REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); -REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU), - IteratorSetStatsAggregatorOp); } // namespace diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc new file mode 100644 index 00000000000..eb96b8a872c --- /dev/null +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { +namespace { + +class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { + public: + explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + StatsAggregatorResource* stats_aggregator_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &stats_aggregator_resource)); + core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + + *output = new Dataset(ctx, input, stats_aggregator_resource); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + StatsAggregatorResource* stats_aggregator_resource) + : GraphDatasetBase(ctx), + input_(input), + stats_aggregator_resource_(stats_aggregator_resource) { + input_->Ref(); + stats_aggregator_resource_->Ref(); + } + + ~Dataset() override { + input_->Unref(); + stats_aggregator_resource_->Unref(); + } + + std::unique_ptr<IteratorBase> MakeIterator( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::SetStatsAggregator")})); + } + + const DataTypeVector& 
output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() override { + return "SetStatsAggregatorDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented( + "Cannot currently serialize the `stats_aggregator` for a " + "SetStatsAggregatorDataset."); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + StatsAggregatorResource* stats_aggregator_resource = + dataset()->stats_aggregator_resource_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.stats_aggregator_getter = [stats_aggregator_resource]() { + return stats_aggregator_resource->stats_aggregator(); + }; + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = ctx->allocator_getter(); + IteratorContext set_stats_aggregator_ctx(params); + return input_impl_->GetNext(&set_stats_aggregator_ctx, out_tensors, + end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + StatsAggregatorResource* 
stats_aggregator_resource_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU), + SetStatsAggregatorDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index 17103627e07..dd373115806 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/data/stats_aggregator.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 4dc1343e21f..633cd854511 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -14,9 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/stats_aggregator.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 5bd37efac8e..031932d79fe 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -25657,18 +25657,6 @@ op { } is_stateful: true } -op { - name: "IteratorSetStatsAggregator" - input_arg { - name: "iterator_handle" - type: DT_RESOURCE - } - input_arg { - name: "stats_aggregator_handle" - type: DT_RESOURCE - } - is_stateful: true -} op { name: "IteratorToStringHandle" input_arg { diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index b25abbcc678..57f871af32b 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -151,6 +151,14 @@ REGISTER_OP("LatencyStatsDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("SetStatsAggregatorDataset") + .Input("input_dataset: variant") + .Input("stats_aggregator: resource") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("MapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -506,11 +514,6 @@ REGISTER_OP("StatsAggregatorHandle") .Attr("container: string = ''") .Attr("shared_name: string = ''"); -REGISTER_OP("IteratorSetStatsAggregator") - .Input("iterator_handle: resource") - .Input("stats_aggregator_handle: resource") - .SetShapeFn(shape_inference::NoOutputs); - 
REGISTER_OP("StatsAggregatorSummary") .Input("iterator: resource") .Output("summary: string") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index a36608ded34..4ae1c3d7e0b 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -12364,18 +12364,6 @@ op { } is_stateful: true } -op { - name: "IteratorSetStatsAggregator" - input_arg { - name: "iterator_handle" - type: DT_RESOURCE - } - input_arg { - name: "stats_aggregator_handle" - type: DT_RESOURCE - } - is_stateful: true -} op { name: "IteratorToStringHandle" input_arg {