Adds the dataset transformation function `set_stats_aggregator(..)`, which sets the given `stats_aggregator` for aggregating the input dataset's statistics.
PiperOrigin-RevId: 193432590
Parent: 695da2d928
Commit: e9d47fbff0
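For context, a minimal usage sketch of the new transformation, assembled from the updated tests below (module paths as in `tf.contrib.data` at this revision; the sketch itself is not part of the diff):

    from tensorflow.contrib.data.python.ops import stats_ops
    from tensorflow.python.data.ops import dataset_ops

    # Attach an aggregator to the pipeline with the new transformation,
    # instead of subscribing the iterator to the aggregator as before.
    stats_aggregator = stats_ops.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.set_stats_aggregator(stats_aggregator))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()
    # Run iterator.initializer, then next_element and summary_t in a session;
    # the summary then carries the "record_latency" statistics.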
@@ -50,17 +50,17 @@ class StatsDatasetTest(test.TestCase):
     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
 
   def testBytesProduced(self):
+    stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).map(
         lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
-            stats_ops.bytes_produced_stats("bytes_produced"))
+            stats_ops.bytes_produced_stats("bytes_produced")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
     iterator = dataset.make_initializable_iterator()
-    stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run([iterator.initializer, stats_aggregator_subscriber])
+      sess.run(iterator.initializer)
       expected_sum = 0.0
       for i in range(100):
         self.assertAllEqual(
@@ -76,16 +76,16 @@ class StatsDatasetTest(test.TestCase):
     self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
 
   def testLatencyStats(self):
-    dataset = dataset_ops.Dataset.range(100).apply(
-        stats_ops.latency_stats("record_latency"))
-    iterator = dataset.make_initializable_iterator()
     stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
+    dataset = dataset_ops.Dataset.range(100).apply(
+        stats_ops.latency_stats("record_latency")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
+    iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run([iterator.initializer, stats_aggregator_subscriber])
+      sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
         self._assertSummaryHasCount(
@@ -95,16 +95,15 @@ class StatsDatasetTest(test.TestCase):
     self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
 
   def testReinitialize(self):
-    dataset = dataset_ops.Dataset.range(100).apply(
-        stats_ops.latency_stats("record_latency"))
-    iterator = dataset.make_initializable_iterator()
     stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
+    dataset = dataset_ops.Dataset.range(100).apply(
+        stats_ops.latency_stats("record_latency")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
+    iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run(stats_aggregator_subscriber)
       for j in range(5):
         sess.run(iterator.initializer)
         for i in range(100):
@@ -130,17 +129,17 @@ class StatsDatasetTest(test.TestCase):
         sess.run(next_element)
 
   def testMultipleTags(self):
+    stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).apply(
         stats_ops.latency_stats("record_latency")).apply(
-        stats_ops.latency_stats("record_latency_2"))
+        stats_ops.latency_stats("record_latency_2")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
     iterator = dataset.make_initializable_iterator()
-    stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run([iterator.initializer, stats_aggregator_subscriber])
+      sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
         self._assertSummaryHasCount(
@@ -154,17 +153,17 @@ class StatsDatasetTest(test.TestCase):
           sess.run(summary_t), "record_latency_2", 100.0)
 
   def testRepeatedTags(self):
+    stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).apply(
         stats_ops.latency_stats("record_latency")).apply(
-        stats_ops.latency_stats("record_latency"))
+        stats_ops.latency_stats("record_latency")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
     iterator = dataset.make_initializable_iterator()
-    stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run([iterator.initializer, stats_aggregator_subscriber])
+      sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
         self._assertSummaryHasCount(
@@ -174,19 +173,17 @@ class StatsDatasetTest(test.TestCase):
     self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
 
   def testMultipleIteratorsSameAggregator(self):
+    stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).apply(
-        stats_ops.latency_stats("record_latency"))
+        stats_ops.latency_stats("record_latency")).apply(
+        stats_ops.set_stats_aggregator(stats_aggregator))
     iterator_0 = dataset.make_initializable_iterator()
     iterator_1 = dataset.make_initializable_iterator()
-    stats_aggregator = stats_ops.StatsAggregator()
-    stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
-                                    stats_aggregator.subscribe(iterator_1)]
     next_element = iterator_0.get_next() + iterator_1.get_next()
     summary_t = stats_aggregator.get_summary()
 
     with self.test_session() as sess:
-      sess.run([iterator_0.initializer, iterator_1.initializer,
-                stats_aggregator_subscribers])
+      sess.run([iterator_0.initializer, iterator_1.initializer])
       for i in range(100):
         self.assertEqual(i * 2, sess.run(next_element))
         self._assertSummaryHasCount(
@@ -195,20 +192,6 @@ class StatsDatasetTest(test.TestCase):
         sess.run(next_element)
       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
 
-  def testMultipleStatsAggregatorsSameIteratorFail(self):
-    dataset = dataset_ops.Dataset.range(100).apply(
-        stats_ops.latency_stats("record_latency"))
-    iterator = dataset.make_initializable_iterator()
-    stats_aggregator_0 = stats_ops.StatsAggregator()
-    stats_aggregator_1 = stats_ops.StatsAggregator()
-
-    with self.test_session() as sess:
-      sess.run(stats_aggregator_0.subscribe(iterator))
-      # TODO(mrry): Consider making this allowable (and also allowing
-      # aggregators to unsubscribe).
-      with self.assertRaises(errors.FailedPreconditionError):
-        sess.run(stats_aggregator_1.subscribe(iterator))
-
 
 class StatsDatasetSerializationTest(
     dataset_serialization_test_base.DatasetSerializationTestBase):
@@ -253,5 +236,9 @@ class StatsDatasetSerializationTest(
         None, num_outputs)
 
 
+# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
+# transformation `stats_ops.set_stats_aggregator`, since we don't support
+# serializing StatsAggregator yet.
+
 if __name__ == "__main__":
   test.main()
@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
 from tensorflow.python.framework import dtypes
@@ -85,25 +84,53 @@ class StatsAggregator(object):
     """
     return gen_dataset_ops.stats_aggregator_summary(self._resource)
 
-  def subscribe(self, iterator):
-    """Returns a @{tf.Operation} to associate this aggregator with `iterator`.
-
-    Note: Each @{tf.data.Iterator} can be associated with at most one
-    `StatsAggregator`. After running the operation that this function
-    returns, all statistics recorded in the iteration of `iterator`
-    will be stored in `stats_aggregator`.
+
+class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+  """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
 
-    Args:
-      iterator: A @{tf.data.Iterator} object.
+  def __init__(self, input_dataset, stats_aggregator):
+    super(_SetStatsAggregatorDataset, self).__init__()
+    self._input_dataset = input_dataset
+    self._stats_aggregator = stats_aggregator
 
-    Returns:
-      A @{tf.Operation} that, when run, associates this aggregator with
-      `iterator`.
-    """
-    if not isinstance(iterator, iterator_ops.Iterator):
-      raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
-    return gen_dataset_ops.iterator_set_stats_aggregator(
-        iterator._iterator_resource, self._resource)  # pylint: disable=protected-access
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.set_stats_aggregator_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        self._stats_aggregator._resource,  # pylint: disable=protected-access
+        output_types=nest.flatten(
+            sparse.as_dense_types(self.output_types, self.output_classes)),
+        output_shapes=nest.flatten(
+            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+
+# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+def set_stats_aggregator(stats_aggregator):
+  """Set the given stats_aggregator for aggregating the input dataset stats.
+
+  Args:
+    stats_aggregator: A `StatsAggregator` object.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+
+  def _apply_fn(dataset):
+    return _SetStatsAggregatorDataset(dataset, stats_aggregator)
+
+  return _apply_fn
 
 
 def bytes_produced_stats(tag):
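As the hunk above shows, `set_stats_aggregator(..)` follows the usual `Dataset.apply` convention: it returns a function from `Dataset` to `Dataset`. A minimal sketch of that convention, with a hypothetical `_MyIdentityDataset` wrapper standing in for `_SetStatsAggregatorDataset` (illustration only, not part of this change):

    def my_transformation(param):
      """Returns a Dataset-to-Dataset function for use with Dataset.apply()."""

      def _apply_fn(dataset):
        # Wrap the input dataset; an identity wrapper forwards elements
        # unchanged and carries `param` along with the pipeline.
        return _MyIdentityDataset(dataset, param)

      return _apply_fn

    # Usage: dataset = dataset.apply(my_transformation(value))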
@@ -547,6 +547,7 @@ tf_cuda_library(
         "framework/selective_registration.h",
         "framework/session_state.h",
         "framework/shape_inference.h",
+        "framework/stats_aggregator.h",
         "framework/tensor.h",
         "framework/tensor_shape.h",
         "framework/tensor_slice.h",
@@ -1,4 +0,0 @@
-op {
-  graph_op_name: "IteratorSetStatsAggregator"
-  summary: "Associates the given iterator with the given statistics aggregator."
-}
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "SetStatsAggregatorDataset"
+}
@@ -1,4 +0,0 @@
-op {
-  graph_op_name: "IteratorSetStatsAggregator"
-  visibility: HIDDEN
-}
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SetStatsAggregatorDataset"
+  visibility: HIDDEN
+}
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
-#define TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
 
 #include <memory>
 #include <string>
@@ -81,4 +81,4 @@ class StatsAggregatorResource : public ResourceBase {
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
+#endif  // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
@@ -13,20 +13,10 @@ load(
     "tf_cc_test",
 )
 
-cc_library(
-    name = "stats_aggregator",
-    hdrs = ["stats_aggregator.h"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-    ],
-)
-
 tf_kernel_library(
     name = "stats_aggregator_ops",
     srcs = ["stats_aggregator_ops.cc"],
     deps = [
-        ":stats_aggregator",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -38,14 +28,7 @@ cc_library(
     name = "dataset",
     srcs = [],
     hdrs = ["dataset.h"],
-    deps = [
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:graph",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-    ],
+    deps = ["//tensorflow/core:framework"],
 )
 
 cc_library(
@@ -360,7 +343,6 @@ tf_kernel_library(
     srcs = ["stats_dataset_ops.cc"],
     deps = [
         ":dataset",
-        ":stats_aggregator",
         "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -368,6 +350,16 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "stats_aggregator_dataset_op",
+    srcs = ["stats_aggregator_dataset_op.cc"],
+    deps = [
+        ":dataset",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
 tf_kernel_library(
     name = "random_dataset_op",
     srcs = ["random_dataset_op.cc"],
@@ -510,7 +502,6 @@ tf_kernel_library(
     srcs = ["iterator_ops.cc"],
     deps = [
         ":dataset",
-        ":stats_aggregator",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
@@ -564,6 +555,7 @@ tf_kernel_library(
         ":slide_dataset_op",
         ":sparse_tensor_slice_dataset_op",
         ":sql_dataset_ops",
+        ":stats_aggregator_dataset_op",
         ":stats_aggregator_ops",
         ":stats_dataset_ops",
         ":take_dataset_op",
@@ -19,11 +19,11 @@ limitations under the License.
 #include "tensorflow/core/framework/iterator.pb.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/stats_aggregator.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -203,10 +203,6 @@ class IteratorResource : public ResourceBase {
     return Status::OK();
   }
 
-  void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) {
-    mutex_lock l(mu_);
-    stats_aggregator_ = std::move(stats_aggregator);
-  }
 
   std::shared_ptr<StatsAggregator> stats_aggregator() {
     tf_shared_lock l(mu_);
@@ -1075,30 +1071,6 @@ class DeserializeIteratorOp : public OpKernel {
   }
 };
 
-class IteratorSetStatsAggregatorOp : public OpKernel {
- public:
-  explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    IteratorResource* iterator_resource;
-    OP_REQUIRES_OK(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
-    core::ScopedUnref unref_iterator(iterator_resource);
-
-    StatsAggregatorResource* stats_aggregator_resource;
-    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
-                                       &stats_aggregator_resource));
-    core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
-    // TODO(mrry): Consider allowing multiple StatsAggregator ops to
-    // subscribe to updates, and/or unsubscribing.
-    OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(),
-                errors::FailedPrecondition(
-                    "Iterator already associated with a StatsAggregator"));
-    iterator_resource->set_stats_aggregator(
-        stats_aggregator_resource->stats_aggregator());
-  }
-};
-
 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
@@ -1119,8 +1091,6 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                         SerializeIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                         DeserializeIteratorOp);
-REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU),
-                        IteratorSetStatsAggregatorOp);
 
 }  // namespace
tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc (new file, 135 lines)
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace {
+
+class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    StatsAggregatorResource* stats_aggregator_resource;
+    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+                                       &stats_aggregator_resource));
+    core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+
+    *output = new Dataset(ctx, input, stats_aggregator_resource);
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+                     StatsAggregatorResource* stats_aggregator_resource)
+        : GraphDatasetBase(ctx),
+          input_(input),
+          stats_aggregator_resource_(stats_aggregator_resource) {
+      input_->Ref();
+      stats_aggregator_resource_->Ref();
+    }
+
+    ~Dataset() override {
+      input_->Unref();
+      stats_aggregator_resource_->Unref();
+    }
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::SetStatsAggregator")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return input_->output_shapes();
+    }
+
+    string DebugString() override {
+      return "SetStatsAggregatorDatasetOp::Dataset";
+    }
+
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      return errors::Unimplemented(
+          "Cannot currently serialize the `stats_aggregator` for a "
+          "SetStatsAggregatorDataset.");
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        StatsAggregatorResource* stats_aggregator_resource =
+            dataset()->stats_aggregator_resource_;
+        IteratorContext::Params params;
+        params.env = ctx->env();
+        params.runner = *(ctx->runner());
+        params.stats_aggregator_getter = [stats_aggregator_resource]() {
+          return stats_aggregator_resource->stats_aggregator();
+        };
+        params.lib = ctx->lib();
+        params.function_library = ctx->function_library();
+        params.allocator_getter = ctx->allocator_getter();
+        IteratorContext set_stats_aggregator_ctx(params);
+        return input_impl_->GetNext(&set_stats_aggregator_ctx, out_tensors,
+                                    end_of_sequence);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* const input_;
+    StatsAggregatorResource* stats_aggregator_resource_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
+                        SetStatsAggregatorDatasetOp);
+}  // namespace
+}  // namespace tensorflow
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/kernels/data/stats_aggregator.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 
 #include <memory>
 
@@ -14,9 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/stats_aggregator.h"
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
@@ -25657,18 +25657,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "IteratorSetStatsAggregator"
-  input_arg {
-    name: "iterator_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "stats_aggregator_handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
 op {
   name: "IteratorToStringHandle"
   input_arg {
@@ -151,6 +151,14 @@ REGISTER_OP("LatencyStatsDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("SetStatsAggregatorDataset")
+    .Input("input_dataset: variant")
+    .Input("stats_aggregator: resource")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("MapDataset")
     .Input("input_dataset: variant")
     .Input("other_arguments: Targuments")
@@ -506,11 +514,6 @@ REGISTER_OP("StatsAggregatorHandle")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''");
 
-REGISTER_OP("IteratorSetStatsAggregator")
-    .Input("iterator_handle: resource")
-    .Input("stats_aggregator_handle: resource")
-    .SetShapeFn(shape_inference::NoOutputs);
-
 REGISTER_OP("StatsAggregatorSummary")
     .Input("iterator: resource")
     .Output("summary: string")
@@ -12364,18 +12364,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "IteratorSetStatsAggregator"
-  input_arg {
-    name: "iterator_handle"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "stats_aggregator_handle"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
 op {
   name: "IteratorToStringHandle"
   input_arg {