Adds a dataset transformation function, set_stats_aggregator(...), which sets the given stats_aggregator for aggregating the input dataset's statistics.

PiperOrigin-RevId: 193432590
This commit is contained in:
Shivani Agrawal 2018-04-18 16:01:55 -07:00 committed by TensorFlower Gardener
parent 695da2d928
commit e9d47fbff0
16 changed files with 242 additions and 152 deletions

View File

@ -50,17 +50,17 @@ class StatsDatasetTest(test.TestCase):
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
stats_ops.bytes_produced_stats("bytes_produced"))
stats_ops.bytes_produced_stats("bytes_produced")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator.initializer, stats_aggregator_subscriber])
sess.run(iterator.initializer)
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
@ -76,16 +76,16 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
def testLatencyStats(self):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
iterator = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator.initializer, stats_aggregator_subscriber])
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@ -95,16 +95,15 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
def testReinitialize(self):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
iterator = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run(stats_aggregator_subscriber)
for j in range(5):
sess.run(iterator.initializer)
for i in range(100):
@ -130,17 +129,17 @@ class StatsDatasetTest(test.TestCase):
sess.run(next_element)
def testMultipleTags(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency_2"))
stats_ops.latency_stats("record_latency_2")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator.initializer, stats_aggregator_subscriber])
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@ -154,17 +153,17 @@ class StatsDatasetTest(test.TestCase):
sess.run(summary_t), "record_latency_2", 100.0)
def testRepeatedTags(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency"))
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator.initializer, stats_aggregator_subscriber])
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@ -174,19 +173,17 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator_0 = dataset.make_initializable_iterator()
iterator_1 = dataset.make_initializable_iterator()
stats_aggregator = stats_ops.StatsAggregator()
stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
stats_aggregator.subscribe(iterator_1)]
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer,
stats_aggregator_subscribers])
sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
self._assertSummaryHasCount(
@ -195,20 +192,6 @@ class StatsDatasetTest(test.TestCase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleStatsAggregatorsSameIteratorFail(self):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
iterator = dataset.make_initializable_iterator()
stats_aggregator_0 = stats_ops.StatsAggregator()
stats_aggregator_1 = stats_ops.StatsAggregator()
with self.test_session() as sess:
sess.run(stats_aggregator_0.subscribe(iterator))
# TODO(mrry): Consider making this allowable (and also allowing
# aggregators to unsubscribe).
with self.assertRaises(errors.FailedPreconditionError):
sess.run(stats_aggregator_1.subscribe(iterator))
class StatsDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
@ -253,5 +236,9 @@ class StatsDatasetSerializationTest(
None, num_outputs)
# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
# transformation `stats_ops.set_stats_aggregator`, since we don't support
# serializing StatsAggregator yet.
if __name__ == "__main__":
test.main()

View File

@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@ -85,25 +84,53 @@ class StatsAggregator(object):
"""
return gen_dataset_ops.stats_aggregator_summary(self._resource)
def subscribe(self, iterator):
"""Returns a @{tf.Operation} to associate this aggregator with `iterator`.
Note: Each @{tf.data.Iterator} can be associated with at most one
`StatsAggregator`. After running the operation that this function
returns, all statistics recorded in the iteration of `iterator`
will be stored in `stats_aggregator`.
class _SetStatsAggregatorDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
Args:
iterator: A @{tf.data.Iterator} object.
def __init__(self, input_dataset, stats_aggregator):
super(_SetStatsAggregatorDataset, self).__init__()
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
Returns:
A @{tf.Operation} that, when run, associates this aggregator with
`iterator`.
"""
if not isinstance(iterator, iterator_ops.Iterator):
raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
return gen_dataset_ops.iterator_set_stats_aggregator(
iterator._iterator_resource, self._resource) # pylint: disable=protected-access
def _as_variant_tensor(self):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)),
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
@property
def output_classes(self):
return self._input_dataset.output_classes
# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
def set_stats_aggregator(stats_aggregator):
"""Set the given stats_aggregator for aggregating the input dataset stats.
Args:
stats_aggregator: A `StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return _SetStatsAggregatorDataset(dataset, stats_aggregator)
return _apply_fn
def bytes_produced_stats(tag):

View File

@ -547,6 +547,7 @@ tf_cuda_library(
"framework/selective_registration.h",
"framework/session_state.h",
"framework/shape_inference.h",
"framework/stats_aggregator.h",
"framework/tensor.h",
"framework/tensor_shape.h",
"framework/tensor_slice.h",

View File

@ -1,4 +0,0 @@
op {
graph_op_name: "IteratorSetStatsAggregator"
summary: "Associates the given iterator with the given statistics aggregator."
}

View File

@ -0,0 +1,3 @@
op {
graph_op_name: "SetStatsAggregatorDataset"
}

View File

@ -1,4 +0,0 @@
op {
graph_op_name: "IteratorSetStatsAggregator"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SetStatsAggregatorDataset"
visibility: HIDDEN
}

View File

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
#define TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
#ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
#define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
#include <memory>
#include <string>
@ -81,4 +81,4 @@ class StatsAggregatorResource : public ResourceBase {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_

View File

@ -13,20 +13,10 @@ load(
"tf_cc_test",
)
cc_library(
name = "stats_aggregator",
hdrs = ["stats_aggregator.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "stats_aggregator_ops",
srcs = ["stats_aggregator_ops.cc"],
deps = [
":stats_aggregator",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -38,14 +28,7 @@ cc_library(
name = "dataset",
srcs = [],
hdrs = ["dataset.h"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
deps = ["//tensorflow/core:framework"],
)
cc_library(
@ -360,7 +343,6 @@ tf_kernel_library(
srcs = ["stats_dataset_ops.cc"],
deps = [
":dataset",
":stats_aggregator",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -368,6 +350,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "stats_aggregator_dataset_op",
srcs = ["stats_aggregator_dataset_op.cc"],
deps = [
":dataset",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "random_dataset_op",
srcs = ["random_dataset_op.cc"],
@ -510,7 +502,6 @@ tf_kernel_library(
srcs = ["iterator_ops.cc"],
deps = [
":dataset",
":stats_aggregator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@ -564,6 +555,7 @@ tf_kernel_library(
":slide_dataset_op",
":sparse_tensor_slice_dataset_op",
":sql_dataset_ops",
":stats_aggregator_dataset_op",
":stats_aggregator_ops",
":stats_dataset_ops",
":take_dataset_op",

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/stats_aggregator.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@ -203,10 +203,6 @@ class IteratorResource : public ResourceBase {
return Status::OK();
}
void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) {
mutex_lock l(mu_);
stats_aggregator_ = std::move(stats_aggregator);
}
std::shared_ptr<StatsAggregator> stats_aggregator() {
tf_shared_lock l(mu_);
@ -1075,30 +1071,6 @@ class DeserializeIteratorOp : public OpKernel {
}
};
class IteratorSetStatsAggregatorOp : public OpKernel {
public:
explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
core::ScopedUnref unref_iterator(iterator_resource);
StatsAggregatorResource* stats_aggregator_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
// TODO(mrry): Consider allowing multiple StatsAggregator ops to
// subscribe to updates, and/or unsubscribing.
OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(),
errors::FailedPrecondition(
"Iterator already associated with a StatsAggregator"));
iterator_resource->set_stats_aggregator(
stats_aggregator_resource->stats_aggregator());
}
};
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
@ -1119,8 +1091,6 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU),
IteratorSetStatsAggregatorOp);
} // namespace

View File

@ -0,0 +1,135 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace {
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
StatsAggregatorResource* stats_aggregator_resource;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
*output = new Dataset(ctx, input, stats_aggregator_resource);
}
private:
class Dataset : public GraphDatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
StatsAggregatorResource* stats_aggregator_resource)
: GraphDatasetBase(ctx),
input_(input),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
stats_aggregator_resource_->Ref();
}
~Dataset() override {
input_->Unref();
stats_aggregator_resource_->Unref();
}
std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::SetStatsAggregator")}));
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() override {
return "SetStatsAggregatorDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(
"Cannot currently serialize the `stats_aggregator` for a "
"SetStatsAggregatorDataset.");
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
StatsAggregatorResource* stats_aggregator_resource =
dataset()->stats_aggregator_resource_;
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
params.stats_aggregator_getter = [stats_aggregator_resource]() {
return stats_aggregator_resource->stats_aggregator();
};
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
IteratorContext set_stats_aggregator_ctx(params);
return input_impl_->GetNext(&set_stats_aggregator_ctx, out_tensors,
end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
StatsAggregatorResource* stats_aggregator_resource_;
};
};
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
} // namespace tensorflow

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/stats_aggregator.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include <memory>

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/stats_aggregator.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {

View File

@ -25657,18 +25657,6 @@ op {
}
is_stateful: true
}
op {
name: "IteratorSetStatsAggregator"
input_arg {
name: "iterator_handle"
type: DT_RESOURCE
}
input_arg {
name: "stats_aggregator_handle"
type: DT_RESOURCE
}
is_stateful: true
}
op {
name: "IteratorToStringHandle"
input_arg {

View File

@ -151,6 +151,14 @@ REGISTER_OP("LatencyStatsDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@ -506,11 +514,6 @@ REGISTER_OP("StatsAggregatorHandle")
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("IteratorSetStatsAggregator")
.Input("iterator_handle: resource")
.Input("stats_aggregator_handle: resource")
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("StatsAggregatorSummary")
.Input("iterator: resource")
.Output("summary: string")

View File

@ -12364,18 +12364,6 @@ op {
}
is_stateful: true
}
op {
name: "IteratorSetStatsAggregator"
input_arg {
name: "iterator_handle"
type: DT_RESOURCE
}
input_arg {
name: "stats_aggregator_handle"
type: DT_RESOURCE
}
is_stateful: true
}
op {
name: "IteratorToStringHandle"
input_arg {