[data-stats] Adds TF 2.0 support for tf.data StatsAggregator.

PiperOrigin-RevId: 239852211
Shivani Agrawal 2019-03-22 13:21:52 -07:00 committed by TensorFlower Gardener
parent 1a02534e7d
commit 6f566fe41c
21 changed files with 825 additions and 357 deletions


@ -0,0 +1,4 @@
op {
graph_op_name: "StatsAggregatorHandleV2"
visibility: HIDDEN
}


@ -0,0 +1,5 @@
op {
graph_op_name: "StatsAggregatorSetSummaryWriter"
visibility: HIDDEN
summary: "Set a summary_writer_interface to record statistics using given stats_aggregator."
}


@ -528,6 +528,9 @@ class IteratorBase {
return errors::Unimplemented("RestoreInternal");
}
// Returns the number of elements produced by this iterator.
int64 num_elements() const { return node_->num_elements(); }
private:
friend class DatasetBase; // for access to `AddCleanupFunction`
friend class DatasetBaseIterator; // for access to `node_`


@ -24,22 +24,22 @@ limitations under the License.
namespace tensorflow {
class Summary;
class SummaryWriterInterface;
namespace data {
// A `StatsAggregator` accumulates statistics incrementally. A
// `StatsAggregator` can accumulate multiple different statistics, distinguished
// by a string name.
//
// The class currently supports accumulating `Histogram` objects, and we expect
// to add other methods in future.
// The class currently supports accumulating `Histogram` and scalar values as
// well as tfstreamz metrics, and we expect to add other methods in the future.
//
// NOTE(mrry): `StatsAggregator` is a virtual interface because we anticipate
// that many different implementations will the same interface. For example, the
// current implementation in "stats_aggregator_ops.cc" is a simple in-memory
// implementation that integrates with the pull-based summary API, and we may
// add implementations that work with the push-based `SummaryWriterInterface`,
// as well as custom monitoring services.
// that many different implementations will have the same interface. For
// example, "stats_aggregator_ops.cc" contains both a simple in-memory
// implementation that integrates with the pull-based summary API and a
// push-based one built on `SummaryWriterInterface`, and we may add
// implementations that work with other custom monitoring services.
class StatsAggregator {
public:
virtual ~StatsAggregator() {}
@ -47,19 +47,21 @@ class StatsAggregator {
// Add the given `values` to the histogram with the given `name`. Each
// element of `values` will be treated as a separate sample in the histogram.
virtual void AddToHistogram(const string& name,
gtl::ArraySlice<double> values) = 0;
gtl::ArraySlice<double> values,
int64 global_step = -1) = 0;
// TODO(shivaniagarwal): Make double vs. float usage consistent.
// Add the given `value` as Scalar with the given `name`.
virtual void AddScalar(const string& name, float value) = 0;
virtual void AddScalar(const string& name, float value,
int64 global_step = -1) = 0;
// Stores a protocol buffer representation of the aggregator state in the
// given `out_summary`.
// TODO(mrry): Consider separating this method from the `StatsAggregator`
// interface. It is possible that not all implementations will support
// encoding their state as a protocol buffer.
virtual void EncodeToProto(Summary* out_summary) = 0;
// Associates a `summary_writer` with this stats_aggregator.
virtual Status SetSummaryWriter(SummaryWriterInterface* summary_writer) = 0;
// Increments the `label` cell of the counter mapped to `name` by the given `val`.
virtual void IncrementCounter(const string& name, const string& label,
int64 val) = 0;
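For orientation, a minimal Python-side sketch (not part of this commit) of how these named statistics get populated from a `tf.data` pipeline: the transformations and the `record_latency`/`bytes_produced` tags are the ones exercised by the tests later in this commit, and the module paths are TensorFlow-internal and may change.

```python
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops

# One aggregator can accumulate several statistics, each keyed by a string name.
aggregator = stats_aggregator.StatsAggregator()

dataset = dataset_ops.Dataset.range(100)
# Records a histogram named "record_latency" (per-element production latency).
dataset = dataset.apply(stats_ops.latency_stats("record_latency"))
# Records a histogram named "bytes_produced" (per-element output size).
dataset = dataset.apply(stats_ops.bytes_produced_stats("bytes_produced"))

options = dataset_ops.Options()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
# Iterating `dataset` now calls AddToHistogram(...) on the aggregator for each
# element; ops such as FilterDataset additionally call AddScalar(...) and
# IncrementCounter(...) for tfstreamz metrics.
```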


@ -342,6 +342,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:summary_interface",
],
)


@ -33,13 +33,13 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
const string& prefix)
: wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
void AddToHistogram(const string& name,
gtl::ArraySlice<double> values) override {
wrapped_->AddToHistogram(TaggedName(name), values);
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
int64 steps) override {
wrapped_->AddToHistogram(TaggedName(name), values, steps);
}
void AddScalar(const string& name, float value) override {
wrapped_->AddScalar(TaggedName(name), value);
void AddScalar(const string& name, float value, int64 steps) override {
wrapped_->AddScalar(TaggedName(name), value, steps);
}
void EncodeToProto(Summary* out_summary) override {
@ -57,6 +57,10 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
}
}
Status SetSummaryWriter(SummaryWriterInterface* summary_writer) override {
return wrapped_->SetSummaryWriter(summary_writer);
}
private:
string TaggedName(const string& name) const {
if (!tag_.empty()) {


@ -19,11 +19,14 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/events_writer.h"
namespace tensorflow {
namespace data {
@ -44,8 +47,8 @@ class StatsAggregatorImpl : public StatsAggregator {
public:
StatsAggregatorImpl() {}
void AddToHistogram(const string& name,
gtl::ArraySlice<double> values) override {
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
const int64 steps) override {
mutex_lock l(mu_);
histogram::Histogram& histogram = histograms_[name];
for (double value : values) {
@ -53,7 +56,7 @@ class StatsAggregatorImpl : public StatsAggregator {
}
}
void AddScalar(const string& name, float value) override {
void AddScalar(const string& name, float value, const int64 steps) override {
mutex_lock l(mu_);
scalars_[name] = value;
}
@ -76,6 +79,13 @@ class StatsAggregatorImpl : public StatsAggregator {
}
}
// The StatsAggregator implementation for V2 is based on push-based summaries;
// setting a summary writer is a no-op in V1.
Status SetSummaryWriter(
SummaryWriterInterface* summary_writer_interface) override {
return Status::OK();
}
void IncrementCounter(const string& name, const string& label,
int64 val) override {
mutex_lock l(*get_counters_map_lock());
@ -112,8 +122,125 @@ class StatsAggregatorHandleOp
new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
return Status::OK();
}
};
Status VerifyResource(StatsAggregatorResource* resource) override {
class StatsAggregatorImplV2 : public StatsAggregator {
public:
StatsAggregatorImplV2() {}
~StatsAggregatorImplV2() override {
if (summary_writer_interface_) {
summary_writer_interface_->Unref();
}
}
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
const int64 steps) override {
mutex_lock l(mu_);
histogram::Histogram& histogram = histograms_[name];
for (double value : values) {
histogram.Add(value);
}
AddToEvents(name, steps, histogram);
}
void AddScalar(const string& name, float value, const int64 steps) override {
mutex_lock l(mu_);
AddToEvents(name, steps, value);
}
// TODO(b/116314787): expose this as a public API to manually flush the summary.
Status Flush() {
mutex_lock l(mu_);
if (summary_writer_interface_)
TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
return Status::OK();
}
void IncrementCounter(const string& name, const string& label,
int64 val) override {
mutex_lock l(*get_counters_map_lock());
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
name, monitoring::Counter<1>::New(
/*streamz name*/ "/tensorflow/" + name,
/*streamz description*/
name + " generated or consumed by the component.",
/*streamz label name*/ "component_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(val);
}
// The StatsAggregator implementation for V1 is based on pull-based summaries;
// encoding to a Summary proto is a no-op in V2.
void EncodeToProto(Summary* out_summary) override {}
Status SetSummaryWriter(
SummaryWriterInterface* summary_writer_interface) override {
mutex_lock l(mu_);
if (summary_writer_interface_) {
return errors::FailedPrecondition(
"The SummaryWriter for a StatsAggregator may only be set once.");
} else {
summary_writer_interface_ = summary_writer_interface;
summary_writer_interface_->Ref();
return Status::OK();
}
}
private:
void AddToEvents(const string& name, const int64 steps,
const float scalar_value) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (summary_writer_interface_ == nullptr) {
return;
}
std::unique_ptr<Event> e{new Event};
e->set_step(steps);
tensorflow::Env* env = tensorflow::Env::Default();
e->set_wall_time(env->NowMicros() / 1.0e6);
// maybe expose GetWallTime in SummaryWriterInterface
Summary::Value* v = e->mutable_summary()->add_value();
v->set_tag(name);
v->set_simple_value(scalar_value);
TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
}
void AddToEvents(const string& name, const int64 steps,
const histogram::Histogram& histogram)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (summary_writer_interface_ == nullptr) {
return;
}
std::unique_ptr<Event> e{new Event};
e->set_step(steps);
tensorflow::Env* env = tensorflow::Env::Default();
e->set_wall_time(env->NowMicros() / 1.0e6);
Summary::Value* v = e->mutable_summary()->add_value();
v->set_tag(name);
histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
}
mutex mu_;
SummaryWriterInterface* summary_writer_interface_ GUARDED_BY(mu_) = nullptr;
// Not owned; we might be associating the default summary_writer from the
// context.
std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
};
class StatsAggregatorHandleOpV2
: public ResourceOpKernel<StatsAggregatorResource> {
public:
explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
: ResourceOpKernel<StatsAggregatorResource>(ctx) {}
private:
Status CreateResource(StatsAggregatorResource** ret) override
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret =
new StatsAggregatorResource(absl::make_unique<StatsAggregatorImplV2>());
return Status::OK();
}
};
@ -141,12 +268,45 @@ class StatsAggregatorSummaryOp : public OpKernel {
}
};
class StatsAggregatorSetSummaryWriterOp : public OpKernel {
public:
explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
StatsAggregatorResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
core::ScopedUnref unref_iterator(resource);
const Tensor& summary_resource_handle_t = ctx->input(1);
OP_REQUIRES(ctx,
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
SummaryWriterInterface* summary_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &summary_resource));
core::ScopedUnref unref_summary_resource(summary_resource);
TF_CHECK_OK(
resource->stats_aggregator()->SetSummaryWriter(summary_resource));
}
};
REGISTER_KERNEL_BUILDER(
Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
StatsAggregatorHandleOp);
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
StatsAggregatorHandleOpV2);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
REGISTER_KERNEL_BUILDER(
Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
StatsAggregatorSetSummaryWriterOp);
} // namespace
} // namespace data
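To make the push path concrete, here is a hedged Python rendering of the event that the scalar branch of `AddToEvents` above produces; the proto fields mirror the C++ code, while the tag and values are placeholders.

```python
import time

from tensorflow.core.util import event_pb2

# Equivalent of the C++ AddToEvents(name, steps, scalar_value) body:
event = event_pb2.Event()
event.step = 42                       # steps, i.e. the iterator's num_elements()
event.wall_time = time.time()         # env->NowMicros() / 1.0e6
value = event.summary.value.add()
value.tag = "Prefetch::buffer_size"   # the stat's string name (placeholder)
value.simple_value = 8.0              # the recorded scalar (placeholder)
# The C++ code then hands this event to SummaryWriterInterface::WriteEvent.
```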


@ -108,8 +108,9 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
uint64 end = ctx->env()->NowMicros();
auto stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator && !*end_of_sequence) {
ctx->stats_aggregator()->AddToHistogram(
dataset()->tag_, {static_cast<double>(end - start)});
int64 steps = num_elements();
stats_aggregator->AddToHistogram(
dataset()->tag_, {static_cast<double>(end - start)}, steps);
}
return s;
}
@ -220,8 +221,9 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
for (const Tensor& t : *out_tensors) {
total_bytes += t.TotalBytes();
}
ctx->stats_aggregator()->AddToHistogram(
dataset()->tag_, {static_cast<double>(total_bytes)});
int64 steps = num_elements();
stats_aggregator->AddToHistogram(
dataset()->tag_, {static_cast<double>(total_bytes)}, steps);
}
return s;
}


@ -211,10 +211,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (stats_aggregator) {
mutex_lock l(mu_);
dropped_elements_++;
int64 steps = num_elements();
stats_aggregator->AddScalar(
stats_utils::DroppedElementsScalarName(
dataset()->node_name()),
static_cast<float>((dropped_elements_)));
static_cast<float>(dropped_elements_), steps);
stats_aggregator->IncrementCounter(dataset()->node_name(),
stats_utils::kDroppedElements,
@ -227,9 +228,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (stats_aggregator) {
mutex_lock l(mu_);
filtered_elements_++;
int64 steps = num_elements();
stats_aggregator->AddScalar(
stats_utils::FilterdElementsScalarName(dataset()->node_name()),
static_cast<float>((filtered_elements_)));
static_cast<float>(filtered_elements_), steps);
stats_aggregator->IncrementCounter(dataset()->node_name(),
stats_utils::kFilteredElements,


@ -139,12 +139,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
if (stats_aggregator) {
int64 steps = num_elements();
stats_aggregator->AddScalar(
stats_utils::BufferSizeScalarName(dataset()->node_name()),
static_cast<float>(buffer_.size()));
static_cast<float>(buffer_.size()), steps);
stats_aggregator->AddScalar(
stats_utils::BufferCapacityScalarName(dataset()->node_name()),
static_cast<float>(auto_tuner_.buffer_limit()));
static_cast<float>(auto_tuner_.buffer_limit()), steps);
}
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@ -232,16 +233,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
int64 steps = num_elements();
stats_aggregator->AddToHistogram(
stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
static_cast<float>(auto_tuner_.buffer_limit())},
steps);
stats_aggregator->AddScalar(
stats_utils::BufferSizeScalarName(dataset()->node_name()),
static_cast<float>(buffer_.size()));
static_cast<float>(buffer_.size()), steps);
stats_aggregator->AddScalar(
stats_utils::BufferCapacityScalarName(dataset()->node_name()),
static_cast<float>(auto_tuner_.buffer_limit()));
static_cast<float>(auto_tuner_.buffer_limit()), steps);
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.


@ -17,6 +17,11 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("StatsAggregatorSetSummaryWriter")
.Input("stats_aggregator: resource")
.Input("summary: resource")
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("ExperimentalAutoShardDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
@ -383,6 +388,12 @@ REGISTER_OP("ExperimentalStatsAggregatorHandle")
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("StatsAggregatorHandleV2")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("container: string = ''")
.Attr("shared_name: string = ''");
REGISTER_OP("ExperimentalStatsAggregatorSummary")
.Input("iterator: resource")
.Output("summary: string")


@ -27,6 +27,18 @@ py_test(
],
)
py_test(
name = "cardinality_test",
srcs = ["cardinality_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "copy_to_device_test",
size = "small",
@ -367,18 +379,6 @@ py_test(
],
)
py_test(
name = "cardinality_test",
srcs = ["cardinality_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "override_threadpool_test",
size = "small",


@ -25,7 +25,8 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
# TODO(b/116314787): add test coverage for StatsAggregatorV2.
@test_util.run_v1_only("b/116314787, add test coverage")
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
def testLatencyStatsOptimization(self):
@ -46,11 +47,13 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
num_test_iterations=1)
summary_t = aggregator.get_summary()
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "record_latency::TensorDataset/_1",
1)
self._assertSummaryHasCount(summary_str, "record_latency::MapDataset/_4", 1)
self._assertSummaryHasCount(summary_str,
"record_latency::PrefetchDataset/_6", 1)
self.assertSummaryHasCount(
summary_str, self.regexForNodeName("record_latency::TensorDataset"), 1)
self.assertSummaryHasCount(
summary_str, self.regexForNodeName("record_latency::MapDataset"), 1)
self.assertSummaryHasCount(
summary_str, self.regexForNodeName("record_latency::PrefetchDataset"),
1)
if __name__ == "__main__":


@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
@ -35,174 +34,312 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def function_set_stats_aggregator(dataset,
aggregator,
prefix="",
counter_prefix=""):
return dataset.apply(
stats_ops.set_stats_aggregator(aggregator, prefix, counter_prefix))
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
options = dataset_ops.Options()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.prefix = prefix
options.experimental_stats.counter_prefix = counter_prefix
options.experimental_stats.latency_all_edges = False
return dataset.with_options(options)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
("SetStatsAggregator", function_set_stats_aggregator),
("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self, dataset_transformation):
def testBytesProduced(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
stats_ops.bytes_produced_stats("bytes_produced"))
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "bytes_produced", float(i + 1),
i + 2)
expected_sum += i * 8.0
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, i + 2)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101)
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101)
def testLatencyStats(self, dataset_transformation):
def testLatencyStats(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
float(i + 1))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", float(i + 1),
i + 2)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101)
def testPrefetchBufferUtilization(self, dataset_transformation):
def testPrefetchBufferUtilization(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasCount(
summary_str,
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle,
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
float(i + 1))
self._assertSummaryContains(
summary_str,
self.regexForNodeName("PrefetchDataset", "buffer_capacity"))
self._assertSummaryContains(
summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"))
self._assertSummaryHasRange(
summary_str,
self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 0, 1)
float(i + 1),
3 * i + 4,
offset=2)
self.assertStatisticsContains(
handle, self.regexForNodeName("PrefetchDataset", "buffer_capacity"),
3 * i + 4)
self.assertStatisticsContains(
handle,
self.regexForNodeName("PrefetchDataset", "buffer_size"),
3 * i + 4,
offset=1)
self.assertStatisticsHasRange(
handle,
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
0,
1,
3 * i + 4,
offset=2)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasCount(
summary_str,
self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 100)
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle,
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
100,
301,
offset=2)
def testPrefetchBufferScalars(self, dataset_transformation):
def testPrefetchBufferScalars(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(1)
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(10):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasScalarValue(
summary_str,
self.regexForNodeName("PrefetchDataset", "buffer_capacity"), 1)
self._assertSummaryHasScalarValue(
summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"),
1)
handle = self.getHandle(aggregator)
self.assertStatisticsHasScalarValue(
handle, self.regexForNodeName("PrefetchDataset", "buffer_capacity"),
1, 3 * i + 4)
self.assertStatisticsHasScalarValue(
handle,
self.regexForNodeName("PrefetchDataset", "buffer_size"),
1,
3 * i + 4,
offset=1)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
def testFilteredElementsStats(self, dataset_transformation):
def testFilteredElementsStats(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter(
lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(34):
self.assertEqual(i * 3, self.evaluate(next_element()))
summary_str = self.evaluate(aggregator.get_summary())
handle = self.getHandle(aggregator)
if i != 0:
self._assertSummaryHasScalarValue(
summary_str,
self.regexForNodeName("FilterDataset", "dropped_elements"),
self.assertStatisticsHasScalarValue(
handle, self.regexForNodeName("FilterDataset", "dropped_elements"),
float(i * 2))
self._assertSummaryHasScalarValue(
summary_str,
self.regexForNodeName("FilterDataset", "filtered_elements"),
self.assertStatisticsHasScalarValue(
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasScalarValue(
summary_str, self.regexForNodeName("FilterDataset", "dropped_elements"),
handle = self.getHandle(aggregator)
self.assertStatisticsHasScalarValue(
handle, self.regexForNodeName("FilterDataset", "dropped_elements"),
67.0)
self._assertSummaryHasScalarValue(
summary_str, self.regexForNodeName("FilterDataset",
"filtered_elements"), 34.0)
self.assertStatisticsHasScalarValue(
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
34.0)
def testMapBufferUtilization(self, dataset_transformation):
def testReinitialize(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = self.datasetExperimentalStats(dataset, aggregator)
for j in range(5):
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency",
float((j * 100) + i + 1),
(j * 100) + i + 2)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0,
(j * 100) + 101)
def testNoAggregatorRegistered(self):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
def testMultipleTags(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency_2"))
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle, "record_latency", float(i + 1), 2 * i + 3, offset=1)
self.assertStatisticsHasCount(handle, "record_latency_2", float(i + 1),
2 * i + 3)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle, "record_latency", 100.0, 201, offset=1)
self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201)
def testRepeatedTags(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency"))
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency",
float(2 * (i + 1)), 2 * i + 3)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
def testMultipleIteratorsSameAggregator(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element1 = self.getNext(dataset, requires_initialization=True)
next_element2 = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency",
float(2 * (i + 1)), 2 * i + 3)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element1())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element2())
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
def testMultipleDatasetWithPrefixes(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = self.datasetExperimentalStats(
dataset, aggregator, prefix="dataset1")
dataset2 = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset2 = self.datasetExperimentalStats(
dataset2, aggregator, prefix="dataset2")
next_element1 = self.getNext(dataset, requires_initialization=True)
next_element2 = self.getNext(dataset2, requires_initialization=True)
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle, "dataset1::record_latency", float(i + 1), 2 * i + 3, offset=1)
self.assertStatisticsHasCount(handle, "dataset2::record_latency",
float(i + 1), 2 * i + 3)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element1())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element2())
handle = self.getHandle(aggregator)
self.assertStatisticsHasCount(
handle, "dataset1::record_latency", 100.0, 201, offset=1)
self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0,
201)
def testMultiplePrefetchStats(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).prefetch(
2).filter(lambda x: math_ops.equal(math_ops.mod(x, 2), 0)).prefetch(1)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(5):
self.assertEqual(i * 2, self.evaluate(next_element()))
handle = self.getHandle(aggregator)
# TODO(shivaniagarwal): We use the exact name of the prefetch node rather than
# a regex to differentiate between the two prefetches. This might break in the
# future, at which point it would be best to disable this test.
self.assertStatisticsHasScalarValue(
handle, "PrefetchDataset/_5::buffer_capacity", 2)
self.assertStatisticsContains(handle, "PrefetchDataset/_5::buffer_size")
self.assertStatisticsHasScalarValue(
handle, "PrefetchDataset/_8::buffer_capacity", 1)
self.assertStatisticsContains(handle, "PrefetchDataset/_8::buffer_size")
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@test_util.run_v1_only("b/116314787, add test coverage")
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
def testMapBufferUtilization(self):
def dataset_fn():
return dataset_ops.Dataset.range(10).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
num_parallel_calls=4)
self._testParallelCallsStats(
dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
10,
dataset_transformation,
function_processing_time=True)
self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
def testMapAutoTuneBufferUtilization(self, dataset_transformation):
def testMapAutoTuneBufferUtilization(self):
def dataset_fn():
return dataset_ops.Dataset.range(10).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
num_parallel_calls=optimization.AUTOTUNE)
self._testParallelCallsStats(
dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
10,
dataset_transformation,
function_processing_time=True)
self.parallelCallsStats(
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
def testInterleaveAutoTuneBufferUtilization(self):
def dataset_fn():
@ -215,11 +352,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
cycle_length=1,
num_parallel_calls=optimization.AUTOTUNE)
self._testParallelCallsStats(
dataset_fn, {self.regexForNodeName("ParallelInterleaveDatasetV2")}, 10,
dataset_transformation)
self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10)
def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):
def testMapAndBatchAutoTuneBufferUtilization(self):
def dataset_fn():
return dataset_ops.Dataset.range(100).apply(
@ -229,172 +364,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
batch_size=16))
num_output = 100 // 16 + 1
self._testParallelCallsStats(
dataset_fn, {self.regexForNodeName("ExperimentalMapAndBatchDataset")},
self.parallelCallsStats(
dataset_fn, {"ExperimentalMapAndBatchDataset"},
num_output,
dataset_transformation,
check_elements=False,
function_processing_time=True)
def testReinitialize(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
for j in range(5):
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
float((j * 100) + i + 1))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
(j + 1) * 100.0)
def testNoAggregatorRegistered(self, dataset_transformation):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
def testMultipleTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency_2"))
dataset = dataset_transformation(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
float(i + 1))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency_2",
float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency_2", 100.0)
def testRepeatedTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
next_element1 = self.getNext(dataset, requires_initialization=True)
next_element2 = self.getNext(dataset, requires_initialization=True)
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency",
float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element1())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element2())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "record_latency", 200.0)
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
dataset2 = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
next_element1 = self.getNext(dataset, requires_initialization=True)
next_element2 = self.getNext(dataset2, requires_initialization=True)
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "dataset1::record_latency",
float(i + 1))
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "dataset2::record_latency",
float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element1())
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element2())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "dataset1::record_latency",
100.0)
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()), "dataset2::record_latency",
100.0)
def testMultiplePrefetchStats(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).prefetch(
2).map(lambda x: math_ops.add(x, 2)).prefetch(1)
dataset = dataset_transformation(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(10):
self.assertEqual(i + 2, self.evaluate(next_element()))
summary_str = self.evaluate(aggregator.get_summary())
# TODO(shivaniagarwal): using exact name of prefetch node than the regex,
# to differentiate between two prefetch. This might break in future, at
# which point, it would be best to disable this test.
self._assertSummaryHasScalarValue(
summary_str, "PrefetchDataset/_5::buffer_capacity", 2)
self._assertSummaryContains(summary_str,
"PrefetchDataset/_5::buffer_size")
self._assertSummaryHasScalarValue(
summary_str, "PrefetchDataset/_8::buffer_capacity", 1)
self._assertSummaryContains(summary_str,
"PrefetchDataset/_8::buffer_size")
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
("SetStatsAggregator", function_set_stats_aggregator),
("StatsOptions", function_apply_options)
)
@test_util.run_v1_only("b/116314787, add test coverage")
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testFeaturesStats(self, dataset_transformation):
def testFeaturesStats(self):
num_epochs = 5
total_records = num_epochs * self._num_records
batch_size = 2
@ -413,13 +395,12 @@ class FeatureStatsDatasetTest(
if total_records % batch_size:
num_output = total_records // batch_size + 1
self._testParallelCallsStats(
dataset_fn, {self.regexForNodeName("ExperimentalParseExampleDataset")},
self.parallelCallsStats(
dataset_fn, {"ExperimentalParseExampleDataset"},
num_output,
dataset_transformation,
check_elements=False)
dataset = dataset_transformation(
dataset = self.datasetExperimentalStats(
dataset_fn(), aggregator, prefix="record_stats")
next_element = self.getNext(dataset, requires_initialization=True)
@ -429,20 +410,21 @@ class FeatureStatsDatasetTest(
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()),
handle = self.getHandle(aggregator)
self.assertSummaryHasCount(
handle,
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
"features_count"), total_records)
self._assertSummaryHasCount(
self.evaluate(aggregator.get_summary()),
self.assertSummaryHasCount(
handle,
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
"feature_values_count"), total_records)
self._assertSummaryHasSum(
self.evaluate(aggregator.get_summary()),
self.assertSummaryHasSum(
handle,
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
"features_count"), total_records * 4)
self._assertSummaryHasSum(
self.evaluate(aggregator.get_summary()),
self.assertSummaryHasSum(
handle,
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
"feature_values_count"),
self._sum_keywords(1) * num_epochs + 3 * total_records)


@ -17,22 +17,106 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def regexForNodeName(self, op_name, stats_type=""):
return "".join([op_name, r"/_\d+::", stats_type])
@classmethod
def setUpClass(cls):
if tf2.enabled():
stats_aggregator._DEFAULT_MAX_QUEUE = 0 # pylint: disable=protected-access
stats_aggregator.StatsAggregator = stats_aggregator.StatsAggregatorV2
# TODO(b/116314787): add graph mode support for StatsAggregatorV2.
else:
stats_aggregator.StatsAggregator = stats_aggregator.StatsAggregatorV1
return test_util.run_all_in_graph_and_eager_modes(cls)
def _assertSummaryContains(self, summary_str, tag):
def datasetExperimentalStats(self,
dataset,
aggregator,
prefix="",
counter_prefix=""):
options = dataset_ops.Options()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.prefix = prefix
options.experimental_stats.counter_prefix = counter_prefix
options.experimental_stats.latency_all_edges = False
return dataset.with_options(options)
def regexForNodeName(self, op_name, stats_type=""):
if stats_type:
return "".join([op_name, r"/_\d+::", stats_type])
return "".join([op_name, r"/_\d+"])
def assertStatisticsContains(self, handle, tag, num_events=-1, offset=0):
if tf2.enabled():
self.assertEventContains(handle, tag, num_events, offset)
else:
self.assertSummaryContains(handle, tag)
def assertStatisticsHasCount(self,
handle,
tag,
count,
num_events=-1,
offset=0):
if tf2.enabled():
self.assertEventHasCount(handle, tag, count, num_events, offset)
else:
self.assertSummaryHasCount(handle, tag, count)
def assertStatisticsHasSum(self,
handle,
tag,
expected_value,
num_events=-1,
offset=0):
if tf2.enabled():
self.assertEventHasSum(handle, tag, expected_value, num_events, offset)
else:
self.assertSummaryHasSum(handle, tag, expected_value)
def assertStatisticsHasScalarValue(self,
handle,
tag,
expected_value,
num_events=-1,
offset=0):
if tf2.enabled():
self.assertEventHasScalarValue(handle, tag, expected_value, num_events,
offset)
else:
self.assertSummaryHasScalarValue(handle, tag, expected_value)
def assertStatisticsHasRange(self,
handle,
tag,
min_value,
max_value,
num_events=-1,
offset=0):
if tf2.enabled():
self.assertEventHasRange(handle, tag, min_value, max_value, num_events,
offset)
else:
self.assertSummaryHasRange(handle, tag, min_value, max_value)
def assertSummaryContains(self, summary_str, tag):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
for value in summary_proto.value:
@ -40,11 +124,11 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def _assertSummaryHasCount(self,
summary_str,
tag,
expected_value,
greater_than=False):
def assertSummaryHasCount(self,
summary_str,
tag,
expected_value,
greater_than=False):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
for value in summary_proto.value:
@ -56,7 +140,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
def assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
for value in summary_proto.value:
@ -66,7 +150,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
def assertSummaryHasSum(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
for value in summary_proto.value:
@ -75,7 +159,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
def assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
for value in summary_proto.value:
@ -84,39 +168,165 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def _testParallelCallsStats(self,
dataset_fn,
dataset_names,
num_output,
dataset_transformation,
function_processing_time=False,
check_elements=True):
def assertEventContains(self, logdir, tag, num_events, offset):
events = _events_from_logdir(logdir)
if num_events == -1:
self.assertGreater(len(events), 1)
for event in events[::-1]:
if re.match(tag, event.summary.value[0].tag):
return
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
else:
self.assertEqual(len(events), num_events)
self.assertTrue(
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
def assertEventHasCount(self, logdir, tag, count, num_events, offset):
events = _events_from_logdir(logdir)
if num_events == -1:
self.assertGreater(len(events), 1)
for event in events[::-1]:
if re.match(tag, event.summary.value[0].tag):
self.assertEqual(count, event.summary.value[0].histo.num)
return
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
else:
self.assertEqual(len(events), num_events)
self.assertTrue(
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
self.assertEqual(
events[num_events - offset - 1].summary.value[0].histo.num, count)
def assertEventHasSum(self, logdir, tag, expected_value, num_events, offset):
events = _events_from_logdir(logdir)
if num_events == -1:
self.assertGreater(len(events), 1)
for event in events[::-1]:
if re.match(tag, event.summary.value[0].tag):
self.assertEqual(expected_value, event.summary.value[0].histo.sum)
return
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
else:
self.assertEqual(len(events), num_events)
self.assertTrue(
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
self.assertEqual(
events[num_events - offset - 1].summary.value[0].histo.sum,
expected_value)
def assertEventHasRange(self, logdir, tag, min_value, max_value, num_events,
offset):
events = _events_from_logdir(logdir)
if num_events == -1:
self.assertGreater(len(events), 1)
for event in events[::-1]:
if re.match(tag, event.summary.value[0].tag):
self.assertLessEqual(min_value, event.summary.value[0].histo.min)
self.assertGreaterEqual(max_value, event.summary.value[0].histo.max)
return
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
else:
self.assertEqual(len(events), num_events)
self.assertTrue(
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
self.assertLessEqual(
min_value, events[num_events - offset - 1].summary.value[0].histo.min)
self.assertGreaterEqual(
max_value, events[num_events - offset - 1].summary.value[0].histo.max)
def assertEventHasScalarValue(self, logdir, tag, expected_value, num_events,
offset):
events = _events_from_logdir(logdir)
if num_events == -1:
self.assertGreater(len(events), 1)
for event in events[::-1]:
if re.match(tag, event.summary.value[0].tag):
self.assertEqual(expected_value, event.summary.value[0].simple_value)
return
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
else:
self.assertEqual(len(events), num_events)
self.assertTrue(
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
self.assertLessEqual(
expected_value,
events[num_events - offset - 1].summary.value[0].simple_value)
def getHandle(self, aggregator):
# pylint: disable=protected-access
if isinstance(aggregator, stats_aggregator.StatsAggregatorV1):
return self.evaluate(aggregator.get_summary())
assert isinstance(aggregator, (stats_aggregator.StatsAggregatorV2))
return aggregator._logdir
def parallelCallsStats(self,
dataset_fn,
dataset_names,
num_output,
function_processing_time=False,
check_elements=True):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_fn()
dataset = dataset_transformation(dataset, aggregator)
dataset = self.datasetExperimentalStats(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
for i in range(num_output):
next_ = self.evaluate(next_element())
if check_elements:
self.assertAllEqual(np.array([i] * i, dtype=np.int64), next_)
summary_str = self.evaluate(aggregator.get_summary())
handle = self.getHandle(aggregator)
for dataset_name in dataset_names:
if function_processing_time:
self._assertSummaryHasCount(
summary_str,
r"(.*)::execution_time$",
float(i + 1),
greater_than=True)
self._assertSummaryContains(summary_str,
dataset_name + "thread_utilization")
self.assertSummaryHasCount(
handle, r"(.*)::execution_time$", float(i + 1), greater_than=True)
self.assertSummaryContains(
handle, self.regexForNodeName(dataset_name, "thread_utilization"))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
if function_processing_time:
summary_str = self.evaluate(aggregator.get_summary())
if isinstance(aggregator, stats_aggregator.StatsAggregatorV1):
handle = self.getHandle(aggregator)
for dataset_name in dataset_names:
self._assertSummaryHasCount(
summary_str,
self.assertSummaryHasCount(
handle,
r"(.*)::execution_time$",
float(num_output),
greater_than=True)
# The following two methods are copied from summary_test_util, since
# summary_test_util lives in contrib.
def _events_from_file(filepath):
"""Returns all events in a single event file.
Args:
filepath: Path to the event file.
Returns:
A list of all tf.Event protos in the event file.
"""
records = list(tf_record.tf_record_iterator(filepath))
result = []
for r in records:
event = event_pb2.Event()
event.ParseFromString(r)
result.append(event)
return result
def _events_from_logdir(logdir):
"""Returns all events in the single eventfile in logdir.
Args:
logdir: The directory in which the single event file is sought.
Returns:
A list of all tf.Event protos from the single event file.
Raises:
AssertionError: If logdir does not contain exactly one file.
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
assert len(files) == 1, "Expected exactly one file in logdir, found: %s" % files
return _events_from_file(os.path.join(logdir, files[0]))
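As a usage sketch (assuming a test method on `StatsDatasetTestBase` with a V2 aggregator already attached to a dataset that records `record_latency`), the helpers above can be combined like this; the filtering logic is illustrative rather than part of the commit.

```python
# Resolve the aggregator handle: the logdir for V2, a serialized Summary for V1.
logdir = self.getHandle(aggregator)
# Parse every event the push-based aggregator has flushed so far.
events = _events_from_logdir(logdir)
# Pick out the most recent "record_latency" histogram and inspect it.
latency_values = [
    event.summary.value[0]
    for event in events
    if event.summary.value and re.match("record_latency",
                                        event.summary.value[0].tag)
]
self.assertTrue(latency_values)
self.assertGreater(latency_values[-1].histo.num, 0)
```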


@ -17,12 +17,71 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
_DEFAULT_MAX_QUEUE = 10
@tf_export("data.experimental.StatsAggregator", v1=[])
class StatsAggregatorV2(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
iterator (see below). For example, to record the latency of producing each
element by iterating over a dataset:
```python
dataset = ...
dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
```
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
aggregator = tf.data.experimental.StatsAggregator()
dataset = ...
# Apply `StatsOptions` to associate `dataset` with `aggregator`.
options = tf.data.Options()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
```
Note: This interface is experimental and expected to change. In particular,
we expect to add other implementations of `StatsAggregator` that provide
different ways of exporting statistics, and add more types of statistics.
"""
def __init__(self):
self._resource = ged_ops.stats_aggregator_handle_v2()
# There could be a conflict between multiple file writers in the same logdir
# (b/37351340). Possible workarounds until this bug is resolved are a) keeping
# a separate, dataset-stats-specific file inside the log_dir and b) using the
# default summary writer. Getting the default summary writer does not quite
# solve the problem, as there might be summary writers in the log dir that are
# not set as the default, e.g. in a Keras callback.
# Creating a summary_writer here could potentially be replaced with getting
# the default summary_writer if any, creating it otherwise, or with a public
# method to associate a summary writer.
self._logdir = tempfile.mkdtemp()
self._summary_writer = summary_ops_v2.create_file_writer(
self._logdir, max_queue=_DEFAULT_MAX_QUEUE)
ged_ops.stats_aggregator_set_summary_writer(self._resource,
self._summary_writer._resource) # pylint: disable=protected-access
@tf_export(v1=["data.experimental.StatsAggregator"])
class StatsAggregatorV1(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
@ -70,7 +129,6 @@ class StatsAggregator(object):
"""Creates a `StatsAggregator`."""
self._resource = ged_ops.experimental_stats_aggregator_handle()
# TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@ -81,3 +139,8 @@ class StatsAggregator(object):
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
"""
return ged_ops.experimental_stats_aggregator_summary(self._resource)
# TODO(b/116314787): Change this to StatsAggregatorV2 once SummaryWriterInterface
# is stable and the change does not break existing users.
StatsAggregator = StatsAggregatorV1
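A short, hedged sketch of the V2 path added in this file: the push-based aggregator writes event files into a private temporary logdir instead of exposing a `get_summary()` tensor; the `_logdir` access below is for illustration only.

```python
from tensorflow.python.data.experimental.ops import stats_aggregator

# Explicitly pick the push-based implementation (in the TF 2.0 API,
# tf.data.experimental.StatsAggregator resolves to this class via the v1=[] export).
aggregator = stats_aggregator.StatsAggregatorV2()

# Associate it with a dataset through tf.data.Options as in the docstring
# above, then iterate the dataset. The aggregator pushes tf.Event protos into
# a temporary log directory.
print(aggregator._logdir)  # pylint: disable=protected-access
# Point TensorBoard at this directory (or parse it with tf_record) to inspect
# the recorded histograms and scalars.
```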


@ -45,7 +45,8 @@ class StatsOptions(options.OptionsBase):
aggregator = options.create_option(
name="aggregator",
ty=stats_aggregator.StatsAggregator,
ty=(stats_aggregator.StatsAggregatorV2,
stats_aggregator.StatsAggregatorV1),
docstring=
"Associates the given statistics aggregator with the dataset pipeline.")


@ -1,6 +1,6 @@
path: "tensorflow.data.experimental.StatsAggregator"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregatorV1\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"


@ -3740,6 +3740,14 @@ tf_module {
name: "StaticRegexReplace"
argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "StatsAggregatorHandleV2"
argspec: "args=[\'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "StatsAggregatorSetSummaryWriter"
argspec: "args=[\'stats_aggregator\', \'summary\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StopGradient"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@ -1,13 +1,9 @@
path: "tensorflow.data.experimental.StatsAggregator"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregatorV2\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_summary"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}


@ -3740,6 +3740,14 @@ tf_module {
name: "StaticRegexReplace"
argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "StatsAggregatorHandleV2"
argspec: "args=[\'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "StatsAggregatorSetSummaryWriter"
argspec: "args=[\'stats_aggregator\', \'summary\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StopGradient"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "