[data-stats] Adds TF 2.0 support for tf.data
StatsAggregator.
PiperOrigin-RevId: 239852211
This commit is contained in:
parent
1a02534e7d
commit
6f566fe41c
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorHandleV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorSetSummaryWriter"
|
||||
visibility: HIDDEN
|
||||
summary: "Set a summary_writer_interface to record statistics using given stats_aggregator."
|
||||
}
|
@ -528,6 +528,9 @@ class IteratorBase {
|
||||
return errors::Unimplemented("RestoreInternal");
|
||||
}
|
||||
|
||||
// Returns the number of elements produced by this itertaor.
|
||||
int64 num_elements() const { return node_->num_elements(); }
|
||||
|
||||
private:
|
||||
friend class DatasetBase; // for access to `AddCleanupFunction`
|
||||
friend class DatasetBaseIterator; // for access to `node_`
|
||||
|
@ -24,22 +24,22 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
class Summary;
|
||||
|
||||
class SummaryWriterInterface;
|
||||
namespace data {
|
||||
|
||||
// A `StatsAggregator` accumulates statistics incrementally. A
|
||||
// `StatsAggregator` can accumulate multiple different statistics, distinguished
|
||||
// by a string name.
|
||||
//
|
||||
// The class currently supports accumulating `Histogram` objects, and we expect
|
||||
// to add other methods in future.
|
||||
// The class currently supports accumulating `Histogram`, `scalar` objects and
|
||||
// tfstreamz metrics, and we expect to add other methods in future.
|
||||
//
|
||||
// NOTE(mrry): `StatsAggregator` is a virtual interface because we anticipate
|
||||
// that many different implementations will the same interface. For example, the
|
||||
// current implementation in "stats_aggregator_ops.cc" is a simple in-memory
|
||||
// implementation that integrates with the pull-based summary API, and we may
|
||||
// add implementations that work with the push-based `SummaryWriterInterface`,
|
||||
// as well as custom monitoring services.
|
||||
// that many different implementations will have the same interface. For
|
||||
// example, we have diffferent implementations in "stats_aggregator_ops.cc" for
|
||||
// simple in-memory implementation that integrates with the pull-based summary
|
||||
// API, and for the push-based `SummaryWriterInterface`, and we may add
|
||||
// implementations that work well with other custom monitoring services.
|
||||
class StatsAggregator {
|
||||
public:
|
||||
virtual ~StatsAggregator() {}
|
||||
@ -47,19 +47,21 @@ class StatsAggregator {
|
||||
// Add the given `values` to the histogram with the given `name`. Each
|
||||
// element of `values` will be treated as a separate sample in the histogram.
|
||||
virtual void AddToHistogram(const string& name,
|
||||
gtl::ArraySlice<double> values) = 0;
|
||||
gtl::ArraySlice<double> values,
|
||||
int64 global_step = -1) = 0;
|
||||
|
||||
// TODO(shivaniagarawal): consistency in double and float usage.
|
||||
// Add the given `value` as Scalar with the given `name`.
|
||||
virtual void AddScalar(const string& name, float value) = 0;
|
||||
virtual void AddScalar(const string& name, float value,
|
||||
int64 global_step = -1) = 0;
|
||||
|
||||
// Stores a protocol buffer representation of the aggregator state in the
|
||||
// given `out_summary`.
|
||||
// TODO(mrry): Consider separating this method from the `StatsAggregator`
|
||||
// interface. It is possible that not all implementations will support
|
||||
// encoding their state as a protocol buffer.
|
||||
virtual void EncodeToProto(Summary* out_summary) = 0;
|
||||
|
||||
// Sets a `summary_writer` with this stats_aggregator.
|
||||
virtual Status SetSummaryWriter(SummaryWriterInterface* summary_writer) = 0;
|
||||
|
||||
// Increment the `label` cell of metrics mapped with `name` by given `value`.
|
||||
virtual void IncrementCounter(const string& name, const string& label,
|
||||
int64 val) = 0;
|
||||
|
@ -342,6 +342,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:summary_interface",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -33,13 +33,13 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
|
||||
const string& prefix)
|
||||
: wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
|
||||
|
||||
void AddToHistogram(const string& name,
|
||||
gtl::ArraySlice<double> values) override {
|
||||
wrapped_->AddToHistogram(TaggedName(name), values);
|
||||
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
|
||||
int64 steps) override {
|
||||
wrapped_->AddToHistogram(TaggedName(name), values, steps);
|
||||
}
|
||||
|
||||
void AddScalar(const string& name, float value) override {
|
||||
wrapped_->AddScalar(TaggedName(name), value);
|
||||
void AddScalar(const string& name, float value, int64 steps) override {
|
||||
wrapped_->AddScalar(TaggedName(name), value, steps);
|
||||
}
|
||||
|
||||
void EncodeToProto(Summary* out_summary) override {
|
||||
@ -57,6 +57,10 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
|
||||
}
|
||||
}
|
||||
|
||||
Status SetSummaryWriter(SummaryWriterInterface* summary_writer) override {
|
||||
return wrapped_->SetSummaryWriter(summary_writer);
|
||||
}
|
||||
|
||||
private:
|
||||
string TaggedName(const string& name) const {
|
||||
if (!tag_.empty()) {
|
||||
|
@ -19,11 +19,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/kernels/summary_interface.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/events_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -44,8 +47,8 @@ class StatsAggregatorImpl : public StatsAggregator {
|
||||
public:
|
||||
StatsAggregatorImpl() {}
|
||||
|
||||
void AddToHistogram(const string& name,
|
||||
gtl::ArraySlice<double> values) override {
|
||||
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
|
||||
const int64 steps) override {
|
||||
mutex_lock l(mu_);
|
||||
histogram::Histogram& histogram = histograms_[name];
|
||||
for (double value : values) {
|
||||
@ -53,7 +56,7 @@ class StatsAggregatorImpl : public StatsAggregator {
|
||||
}
|
||||
}
|
||||
|
||||
void AddScalar(const string& name, float value) override {
|
||||
void AddScalar(const string& name, float value, const int64 steps) override {
|
||||
mutex_lock l(mu_);
|
||||
scalars_[name] = value;
|
||||
}
|
||||
@ -76,6 +79,13 @@ class StatsAggregatorImpl : public StatsAggregator {
|
||||
}
|
||||
}
|
||||
|
||||
// StatsAggregator implementation for V2 is based on push-based summary, no-op
|
||||
// in V1.
|
||||
Status SetSummaryWriter(
|
||||
SummaryWriterInterface* summary_writer_interface) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IncrementCounter(const string& name, const string& label,
|
||||
int64 val) override {
|
||||
mutex_lock l(*get_counters_map_lock());
|
||||
@ -112,8 +122,125 @@ class StatsAggregatorHandleOp
|
||||
new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
Status VerifyResource(StatsAggregatorResource* resource) override {
|
||||
class StatsAggregatorImplV2 : public StatsAggregator {
|
||||
public:
|
||||
StatsAggregatorImplV2() {}
|
||||
|
||||
~StatsAggregatorImplV2() override {
|
||||
if (summary_writer_interface_) {
|
||||
summary_writer_interface_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
|
||||
const int64 steps) override {
|
||||
mutex_lock l(mu_);
|
||||
histogram::Histogram& histogram = histograms_[name];
|
||||
for (double value : values) {
|
||||
histogram.Add(value);
|
||||
}
|
||||
AddToEvents(name, steps, histogram);
|
||||
}
|
||||
|
||||
void AddScalar(const string& name, float value, const int64 steps) override {
|
||||
mutex_lock l(mu_);
|
||||
AddToEvents(name, steps, value);
|
||||
}
|
||||
|
||||
// TODO(b/116314787): expose this is public API to manually flush summary.
|
||||
Status Flush() {
|
||||
mutex_lock l(mu_);
|
||||
if (summary_writer_interface_)
|
||||
TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IncrementCounter(const string& name, const string& label,
|
||||
int64 val) override {
|
||||
mutex_lock l(*get_counters_map_lock());
|
||||
auto counters_map = get_counters_map();
|
||||
if (counters_map->find(name) == counters_map->end()) {
|
||||
counters_map->emplace(
|
||||
name, monitoring::Counter<1>::New(
|
||||
/*streamz name*/ "/tensorflow/" + name,
|
||||
/*streamz description*/
|
||||
name + " generated or consumed by the component.",
|
||||
/*streamz label name*/ "component_descriptor"));
|
||||
}
|
||||
counters_map->at(name)->GetCell(label)->IncrementBy(val);
|
||||
}
|
||||
|
||||
// StatsAggregator implementation for V1 is based on pull-based summary, no-op
|
||||
// in V2.
|
||||
void EncodeToProto(Summary* out_summary) override {}
|
||||
|
||||
Status SetSummaryWriter(
|
||||
SummaryWriterInterface* summary_writer_interface) override {
|
||||
mutex_lock l(mu_);
|
||||
if (summary_writer_interface_) {
|
||||
return errors::FailedPrecondition(
|
||||
"The SummaryWriter for a StatsAggregator may only be set once.");
|
||||
} else {
|
||||
summary_writer_interface_ = summary_writer_interface;
|
||||
summary_writer_interface_->Ref();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void AddToEvents(const string& name, const int64 steps,
|
||||
const float scalar_value) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (summary_writer_interface_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(steps);
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
e->set_wall_time(env->NowMicros() / 1.0e6);
|
||||
// maybe expose GetWallTime in SummaryWriterInterface
|
||||
Summary::Value* v = e->mutable_summary()->add_value();
|
||||
v->set_tag(name);
|
||||
v->set_simple_value(scalar_value);
|
||||
TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
|
||||
}
|
||||
|
||||
void AddToEvents(const string& name, const int64 steps,
|
||||
const histogram::Histogram& histogram)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (summary_writer_interface_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(steps);
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
e->set_wall_time(env->NowMicros() / 1.0e6);
|
||||
Summary::Value* v = e->mutable_summary()->add_value();
|
||||
v->set_tag(name);
|
||||
histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
|
||||
TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
SummaryWriterInterface* summary_writer_interface_ GUARDED_BY(mu_) = nullptr;
|
||||
// not owned, we might be associating the default summary_writer from the
|
||||
// context
|
||||
std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
|
||||
};
|
||||
|
||||
class StatsAggregatorHandleOpV2
|
||||
: public ResourceOpKernel<StatsAggregatorResource> {
|
||||
public:
|
||||
explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
|
||||
: ResourceOpKernel<StatsAggregatorResource>(ctx) {}
|
||||
|
||||
private:
|
||||
Status CreateResource(StatsAggregatorResource** ret) override
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret =
|
||||
new StatsAggregatorResource(absl::make_unique<StatsAggregatorImplV2>());
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
@ -141,12 +268,45 @@ class StatsAggregatorSummaryOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
class StatsAggregatorSetSummaryWriterOp : public OpKernel {
|
||||
public:
|
||||
explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
|
||||
StatsAggregatorResource* resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
core::ScopedUnref unref_iterator(resource);
|
||||
|
||||
const Tensor& summary_resource_handle_t = ctx->input(1);
|
||||
OP_REQUIRES(ctx,
|
||||
TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
SummaryWriterInterface* sumamry_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource));
|
||||
core::ScopedUnref unref_sumamry_resource(sumamry_resource);
|
||||
TF_CHECK_OK(
|
||||
resource->stats_aggregator()->SetSummaryWriter(sumamry_resource));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
|
||||
StatsAggregatorHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
|
||||
StatsAggregatorHandleOpV2);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
|
||||
StatsAggregatorSummaryOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
|
||||
StatsAggregatorSetSummaryWriterOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
|
@ -108,8 +108,9 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
uint64 end = ctx->env()->NowMicros();
|
||||
auto stats_aggregator = ctx->stats_aggregator();
|
||||
if (stats_aggregator && !*end_of_sequence) {
|
||||
ctx->stats_aggregator()->AddToHistogram(
|
||||
dataset()->tag_, {static_cast<double>(end - start)});
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddToHistogram(
|
||||
dataset()->tag_, {static_cast<double>(end - start)}, steps);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
@ -220,8 +221,9 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
for (const Tensor& t : *out_tensors) {
|
||||
total_bytes += t.TotalBytes();
|
||||
}
|
||||
ctx->stats_aggregator()->AddToHistogram(
|
||||
dataset()->tag_, {static_cast<double>(total_bytes)});
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddToHistogram(
|
||||
dataset()->tag_, {static_cast<double>(total_bytes)}, steps);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
@ -211,10 +211,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (stats_aggregator) {
|
||||
mutex_lock l(mu_);
|
||||
dropped_elements_++;
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::DroppedElementsScalarName(
|
||||
dataset()->node_name()),
|
||||
static_cast<float>((dropped_elements_)));
|
||||
static_cast<float>(dropped_elements_), steps);
|
||||
|
||||
stats_aggregator->IncrementCounter(dataset()->node_name(),
|
||||
stats_utils::kDroppedElements,
|
||||
@ -227,9 +228,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (stats_aggregator) {
|
||||
mutex_lock l(mu_);
|
||||
filtered_elements_++;
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::FilterdElementsScalarName(dataset()->node_name()),
|
||||
static_cast<float>((filtered_elements_)));
|
||||
static_cast<float>(filtered_elements_), steps);
|
||||
|
||||
stats_aggregator->IncrementCounter(dataset()->node_name(),
|
||||
stats_utils::kFilteredElements,
|
||||
|
@ -139,12 +139,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
mutex_lock parent_l(parent_mu_);
|
||||
mutex_lock l(mu_);
|
||||
if (stats_aggregator) {
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::BufferSizeScalarName(dataset()->node_name()),
|
||||
static_cast<float>(buffer_.size()));
|
||||
static_cast<float>(buffer_.size()), steps);
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::BufferCapacityScalarName(dataset()->node_name()),
|
||||
static_cast<float>(auto_tuner_.buffer_limit()));
|
||||
static_cast<float>(auto_tuner_.buffer_limit()), steps);
|
||||
}
|
||||
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
@ -232,16 +233,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||
if (stats_aggregator) {
|
||||
int64 steps = num_elements();
|
||||
stats_aggregator->AddToHistogram(
|
||||
stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
|
||||
{static_cast<float>(buffer_.size()) /
|
||||
static_cast<float>(auto_tuner_.buffer_limit())});
|
||||
static_cast<float>(auto_tuner_.buffer_limit())},
|
||||
steps);
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::BufferSizeScalarName(dataset()->node_name()),
|
||||
static_cast<float>(buffer_.size()));
|
||||
static_cast<float>(buffer_.size()), steps);
|
||||
stats_aggregator->AddScalar(
|
||||
stats_utils::BufferCapacityScalarName(dataset()->node_name()),
|
||||
static_cast<float>(auto_tuner_.buffer_limit()));
|
||||
static_cast<float>(auto_tuner_.buffer_limit()), steps);
|
||||
}
|
||||
// A new element is available. Forward the status from computing it, and
|
||||
// (if we successfully got an element) the output values.
|
||||
|
@ -17,6 +17,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("StatsAggregatorSetSummaryWriter")
|
||||
.Input("stats_aggregator: resource")
|
||||
.Input("summary: resource")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("ExperimentalAutoShardDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
@ -383,6 +388,12 @@ REGISTER_OP("ExperimentalStatsAggregatorHandle")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("StatsAggregatorHandleV2")
|
||||
.Output("handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''");
|
||||
|
||||
REGISTER_OP("ExperimentalStatsAggregatorSummary")
|
||||
.Input("iterator: resource")
|
||||
.Output("summary: string")
|
||||
|
@ -27,6 +27,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cardinality_test",
|
||||
srcs = ["cardinality_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/data/experimental/ops:cardinality",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "copy_to_device_test",
|
||||
size = "small",
|
||||
@ -367,18 +379,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cardinality_test",
|
||||
srcs = ["cardinality_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/data/experimental/ops:cardinality",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "override_threadpool_test",
|
||||
size = "small",
|
||||
|
@ -25,7 +25,8 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
# TODO(b/116314787): add test coverage for StatsAggregatorV2.
|
||||
@test_util.run_v1_only("b/116314787, add test coverage")
|
||||
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
def testLatencyStatsOptimization(self):
|
||||
@ -46,11 +47,13 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
num_test_iterations=1)
|
||||
summary_t = aggregator.get_summary()
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "record_latency::TensorDataset/_1",
|
||||
1)
|
||||
self._assertSummaryHasCount(summary_str, "record_latency::MapDataset/_4", 1)
|
||||
self._assertSummaryHasCount(summary_str,
|
||||
"record_latency::PrefetchDataset/_6", 1)
|
||||
self.assertSummaryHasCount(
|
||||
summary_str, self.regexForNodeName("record_latency::TensorDataset"), 1)
|
||||
self.assertSummaryHasCount(
|
||||
summary_str, self.regexForNodeName("record_latency::MapDataset"), 1)
|
||||
self.assertSummaryHasCount(
|
||||
summary_str, self.regexForNodeName("record_latency::PrefetchDataset"),
|
||||
1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
@ -35,174 +34,312 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def function_set_stats_aggregator(dataset,
|
||||
aggregator,
|
||||
prefix="",
|
||||
counter_prefix=""):
|
||||
return dataset.apply(
|
||||
stats_ops.set_stats_aggregator(aggregator, prefix, counter_prefix))
|
||||
|
||||
|
||||
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_stats.aggregator = aggregator
|
||||
options.experimental_stats.prefix = prefix
|
||||
options.experimental_stats.counter_prefix = counter_prefix
|
||||
options.experimental_stats.latency_all_edges = False
|
||||
return dataset.with_options(options)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters(
|
||||
("SetStatsAggregator", function_set_stats_aggregator),
|
||||
("StatsOptions", function_apply_options),
|
||||
)
|
||||
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
def testBytesProduced(self, dataset_transformation):
|
||||
def testBytesProduced(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
|
||||
stats_ops.bytes_produced_stats("bytes_produced"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
expected_sum = 0.0
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "bytes_produced", float(i + 1),
|
||||
i + 2)
|
||||
expected_sum += i * 8.0
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, i + 2)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "bytes_produced", 100.0, 101)
|
||||
self.assertStatisticsHasSum(handle, "bytes_produced", expected_sum, 101)
|
||||
|
||||
def testLatencyStats(self, dataset_transformation):
|
||||
def testLatencyStats(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
float(i + 1))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", float(i + 1),
|
||||
i + 2)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 100.0, 101)
|
||||
|
||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||
def testPrefetchBufferUtilization(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasCount(
|
||||
summary_str,
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
|
||||
float(i + 1))
|
||||
self._assertSummaryContains(
|
||||
summary_str,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_capacity"))
|
||||
self._assertSummaryContains(
|
||||
summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"))
|
||||
self._assertSummaryHasRange(
|
||||
summary_str,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 0, 1)
|
||||
float(i + 1),
|
||||
3 * i + 4,
|
||||
offset=2)
|
||||
self.assertStatisticsContains(
|
||||
handle, self.regexForNodeName("PrefetchDataset", "buffer_capacity"),
|
||||
3 * i + 4)
|
||||
self.assertStatisticsContains(
|
||||
handle,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_size"),
|
||||
3 * i + 4,
|
||||
offset=1)
|
||||
self.assertStatisticsHasRange(
|
||||
handle,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
|
||||
0,
|
||||
1,
|
||||
3 * i + 4,
|
||||
offset=2)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasCount(
|
||||
summary_str,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_utilization"), 100)
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_utilization"),
|
||||
100,
|
||||
301,
|
||||
offset=2)
|
||||
|
||||
def testPrefetchBufferScalars(self, dataset_transformation):
|
||||
def testPrefetchBufferScalars(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(10).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(1)
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(10):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element()))
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_capacity"), 1)
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str, self.regexForNodeName("PrefetchDataset", "buffer_size"),
|
||||
1)
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, self.regexForNodeName("PrefetchDataset", "buffer_capacity"),
|
||||
1, 3 * i + 4)
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle,
|
||||
self.regexForNodeName("PrefetchDataset", "buffer_size"),
|
||||
1,
|
||||
3 * i + 4,
|
||||
offset=1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
def testFilteredElementsStats(self, dataset_transformation):
|
||||
def testFilteredElementsStats(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(101).filter(
|
||||
lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(34):
|
||||
self.assertEqual(i * 3, self.evaluate(next_element()))
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
handle = self.getHandle(aggregator)
|
||||
if i != 0:
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str,
|
||||
self.regexForNodeName("FilterDataset", "dropped_elements"),
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, self.regexForNodeName("FilterDataset", "dropped_elements"),
|
||||
float(i * 2))
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str,
|
||||
self.regexForNodeName("FilterDataset", "filtered_elements"),
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
|
||||
float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str, self.regexForNodeName("FilterDataset", "dropped_elements"),
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, self.regexForNodeName("FilterDataset", "dropped_elements"),
|
||||
67.0)
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str, self.regexForNodeName("FilterDataset",
|
||||
"filtered_elements"), 34.0)
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, self.regexForNodeName("FilterDataset", "filtered_elements"),
|
||||
34.0)
|
||||
|
||||
def testMapBufferUtilization(self, dataset_transformation):
|
||||
def testReinitialize(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
|
||||
for j in range(5):
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency",
|
||||
float((j * 100) + i + 1),
|
||||
(j * 100) + i + 2)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", (j + 1) * 100.0,
|
||||
(j * 100) + 101)
|
||||
|
||||
def testNoAggregatorRegistered(self):
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
def testMultipleTags(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency")).apply(
|
||||
stats_ops.latency_stats("record_latency_2"))
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle, "record_latency", float(i + 1), 2 * i + 3, offset=1)
|
||||
self.assertStatisticsHasCount(handle, "record_latency_2", float(i + 1),
|
||||
2 * i + 3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle, "record_latency", 100.0, 201, offset=1)
|
||||
self.assertStatisticsHasCount(handle, "record_latency_2", 100.0, 201)
|
||||
|
||||
def testRepeatedTags(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency")).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency",
|
||||
float(2 * (i + 1)), 2 * i + 3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
|
||||
|
||||
def testMultipleIteratorsSameAggregator(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element1 = self.getNext(dataset, requires_initialization=True)
|
||||
next_element2 = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency",
|
||||
float(2 * (i + 1)), 2 * i + 3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element1())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element2())
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(handle, "record_latency", 200.0, 201)
|
||||
|
||||
def testMultipleDatasetWithPrefixes(self):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = self.datasetExperimentalStats(
|
||||
dataset, aggregator, prefix="dataset1")
|
||||
dataset2 = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset2 = self.datasetExperimentalStats(
|
||||
dataset2, aggregator, prefix="dataset2")
|
||||
next_element1 = self.getNext(dataset, requires_initialization=True)
|
||||
next_element2 = self.getNext(dataset2, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle, "dataset1::record_latency", float(i + 1), 2 * i + 3, offset=1)
|
||||
self.assertStatisticsHasCount(handle, "dataset2::record_latency",
|
||||
float(i + 1), 2 * i + 3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element1())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element2())
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertStatisticsHasCount(
|
||||
handle, "dataset1::record_latency", 100.0, 201, offset=1)
|
||||
self.assertStatisticsHasCount(handle, "dataset2::record_latency", 100.0,
|
||||
201)
|
||||
|
||||
def testMultiplePrefetchStats(self):
|
||||
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(10).prefetch(
|
||||
2).filter(lambda x: math_ops.equal(math_ops.mod(x, 2), 0)).prefetch(1)
|
||||
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(5):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element()))
|
||||
handle = self.getHandle(aggregator)
|
||||
# TODO(shivaniagarwal): using exact name of prefetch node than the regex,
|
||||
# to differentiate between two prefetch. This might break in future, at
|
||||
# which point, it would be best to disable this test.
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, "PrefetchDataset/_5::buffer_capacity", 2)
|
||||
self.assertStatisticsContains(handle, "PrefetchDataset/_5::buffer_size")
|
||||
self.assertStatisticsHasScalarValue(
|
||||
handle, "PrefetchDataset/_8::buffer_capacity", 1)
|
||||
self.assertStatisticsContains(handle, "PrefetchDataset/_8::buffer_size")
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/116314787, add test coverage")
|
||||
class ThreadUtilizationStatsTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
def testMapBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
return dataset_ops.Dataset.range(10).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
|
||||
num_parallel_calls=4)
|
||||
|
||||
self._testParallelCallsStats(
|
||||
dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
|
||||
10,
|
||||
dataset_transformation,
|
||||
function_processing_time=True)
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
|
||||
|
||||
def testMapAutoTuneBufferUtilization(self, dataset_transformation):
|
||||
def testMapAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
return dataset_ops.Dataset.range(10).map(
|
||||
lambda x: array_ops.tile([x], ops.convert_to_tensor([x])),
|
||||
num_parallel_calls=optimization.AUTOTUNE)
|
||||
|
||||
self._testParallelCallsStats(
|
||||
dataset_fn, {self.regexForNodeName("ParallelMapDataset")},
|
||||
10,
|
||||
dataset_transformation,
|
||||
function_processing_time=True)
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ParallelMapDataset"}, 10, function_processing_time=True)
|
||||
|
||||
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
|
||||
def testInterleaveAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
|
||||
@ -215,11 +352,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
cycle_length=1,
|
||||
num_parallel_calls=optimization.AUTOTUNE)
|
||||
|
||||
self._testParallelCallsStats(
|
||||
dataset_fn, {self.regexForNodeName("ParallelInterleaveDatasetV2")}, 10,
|
||||
dataset_transformation)
|
||||
self.parallelCallsStats(dataset_fn, {"ParallelInterleaveDatasetV2"}, 10)
|
||||
|
||||
def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):
|
||||
def testMapAndBatchAutoTuneBufferUtilization(self):
|
||||
|
||||
def dataset_fn():
|
||||
return dataset_ops.Dataset.range(100).apply(
|
||||
@ -229,172 +364,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
batch_size=16))
|
||||
|
||||
num_output = 100 // 16 + 1
|
||||
self._testParallelCallsStats(
|
||||
dataset_fn, {self.regexForNodeName("ExperimentalMapAndBatchDataset")},
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ExperimentalMapAndBatchDataset"},
|
||||
num_output,
|
||||
dataset_transformation,
|
||||
check_elements=False,
|
||||
function_processing_time=True)
|
||||
|
||||
def testReinitialize(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
|
||||
for j in range(5):
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
float((j * 100) + i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
(j + 1) * 100.0)
|
||||
|
||||
def testNoAggregatorRegistered(self, dataset_transformation):
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
def testMultipleTags(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency")).apply(
|
||||
stats_ops.latency_stats("record_latency_2"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency_2",
|
||||
float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency_2", 100.0)
|
||||
|
||||
def testRepeatedTags(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency")).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency", 200.0)
|
||||
|
||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
next_element1 = self.getNext(dataset, requires_initialization=True)
|
||||
next_element2 = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency",
|
||||
float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element1())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element2())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "record_latency", 200.0)
|
||||
|
||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
|
||||
dataset2 = dataset_ops.Dataset.range(100).apply(
|
||||
stats_ops.latency_stats("record_latency"))
|
||||
dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
|
||||
next_element1 = self.getNext(dataset, requires_initialization=True)
|
||||
next_element2 = self.getNext(dataset2, requires_initialization=True)
|
||||
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element1() + next_element2()))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "dataset1::record_latency",
|
||||
float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "dataset2::record_latency",
|
||||
float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element1())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element2())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "dataset1::record_latency",
|
||||
100.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()), "dataset2::record_latency",
|
||||
100.0)
|
||||
|
||||
def testMultiplePrefetchStats(self, dataset_transformation):
|
||||
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_ops.Dataset.range(10).prefetch(
|
||||
2).map(lambda x: math_ops.add(x, 2)).prefetch(1)
|
||||
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(10):
|
||||
self.assertEqual(i + 2, self.evaluate(next_element()))
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
# TODO(shivaniagarwal): using exact name of prefetch node than the regex,
|
||||
# to differentiate between two prefetch. This might break in future, at
|
||||
# which point, it would be best to disable this test.
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str, "PrefetchDataset/_5::buffer_capacity", 2)
|
||||
self._assertSummaryContains(summary_str,
|
||||
"PrefetchDataset/_5::buffer_size")
|
||||
self._assertSummaryHasScalarValue(
|
||||
summary_str, "PrefetchDataset/_8::buffer_capacity", 1)
|
||||
self._assertSummaryContains(summary_str,
|
||||
"PrefetchDataset/_8::buffer_size")
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters(
|
||||
("SetStatsAggregator", function_set_stats_aggregator),
|
||||
("StatsOptions", function_apply_options)
|
||||
)
|
||||
@test_util.run_v1_only("b/116314787, add test coverage")
|
||||
class FeatureStatsDatasetTest(
|
||||
stats_dataset_test_base.StatsDatasetTestBase,
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
||||
|
||||
def testFeaturesStats(self, dataset_transformation):
|
||||
def testFeaturesStats(self):
|
||||
num_epochs = 5
|
||||
total_records = num_epochs * self._num_records
|
||||
batch_size = 2
|
||||
@ -413,13 +395,12 @@ class FeatureStatsDatasetTest(
|
||||
if total_records % batch_size:
|
||||
num_output = total_records // batch_size + 1
|
||||
|
||||
self._testParallelCallsStats(
|
||||
dataset_fn, {self.regexForNodeName("ExperimentalParseExampleDataset")},
|
||||
self.parallelCallsStats(
|
||||
dataset_fn, {"ExperimentalParseExampleDataset"},
|
||||
num_output,
|
||||
dataset_transformation,
|
||||
check_elements=False)
|
||||
|
||||
dataset = dataset_transformation(
|
||||
dataset = self.datasetExperimentalStats(
|
||||
dataset_fn(), aggregator, prefix="record_stats")
|
||||
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
@ -429,20 +410,21 @@ class FeatureStatsDatasetTest(
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()),
|
||||
handle = self.getHandle(aggregator)
|
||||
self.assertSummaryHasCount(
|
||||
handle,
|
||||
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
|
||||
"features_count"), total_records)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(aggregator.get_summary()),
|
||||
self.assertSummaryHasCount(
|
||||
handle,
|
||||
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
|
||||
"feature_values_count"), total_records)
|
||||
self._assertSummaryHasSum(
|
||||
self.evaluate(aggregator.get_summary()),
|
||||
self.assertSummaryHasSum(
|
||||
handle,
|
||||
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
|
||||
"features_count"), total_records * 4)
|
||||
self._assertSummaryHasSum(
|
||||
self.evaluate(aggregator.get_summary()),
|
||||
self.assertSummaryHasSum(
|
||||
handle,
|
||||
self.regexForNodeName("record_stats::ExperimentalParseExampleDataset",
|
||||
"feature_values_count"),
|
||||
self._sum_keywords(1) * num_epochs + 3 * total_records)
|
||||
|
@ -17,22 +17,106 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import stats_aggregator
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
|
||||
class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
"""Base class for testing statistics gathered in `StatsAggregator`."""
|
||||
|
||||
def regexForNodeName(self, op_name, stats_type=""):
|
||||
return "".join([op_name, r"/_\d+::", stats_type])
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if tf2.enabled():
|
||||
stats_aggregator._DEFAULT_MAX_QUEUE = 0 # pylint: disable=protected-access
|
||||
stats_aggregator.StatsAggregator = stats_aggregator.StatsAggregatorV2
|
||||
# TODO(b/116314787): add graph mode support for StatsAggregatorV2.
|
||||
else:
|
||||
stats_aggregator.StatsAggregator = stats_aggregator.StatsAggregatorV1
|
||||
return test_util.run_all_in_graph_and_eager_modes(cls)
|
||||
|
||||
def _assertSummaryContains(self, summary_str, tag):
|
||||
def datasetExperimentalStats(self,
|
||||
dataset,
|
||||
aggregator,
|
||||
prefix="",
|
||||
counter_prefix=""):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_stats.aggregator = aggregator
|
||||
options.experimental_stats.prefix = prefix
|
||||
options.experimental_stats.counter_prefix = counter_prefix
|
||||
options.experimental_stats.latency_all_edges = False
|
||||
return dataset.with_options(options)
|
||||
|
||||
def regexForNodeName(self, op_name, stats_type=""):
|
||||
if stats_type:
|
||||
return "".join([op_name, r"/_\d+::", stats_type])
|
||||
return "".join([op_name, r"/_\d+"])
|
||||
|
||||
def assertStatisticsContains(self, handle, tag, num_events=-1, offset=0):
|
||||
if tf2.enabled():
|
||||
self.assertEventContains(handle, tag, num_events, offset)
|
||||
else:
|
||||
self.assertSummaryContains(handle, tag)
|
||||
|
||||
def assertStatisticsHasCount(self,
|
||||
handle,
|
||||
tag,
|
||||
count,
|
||||
num_events=-1,
|
||||
offset=0):
|
||||
if tf2.enabled():
|
||||
self.assertEventHasCount(handle, tag, count, num_events, offset)
|
||||
else:
|
||||
self.assertSummaryHasCount(handle, tag, count)
|
||||
|
||||
def assertStatisticsHasSum(self,
|
||||
handle,
|
||||
tag,
|
||||
expected_value,
|
||||
num_events=-1,
|
||||
offset=0):
|
||||
if tf2.enabled():
|
||||
self.assertEventHasSum(handle, tag, expected_value, num_events, offset)
|
||||
else:
|
||||
self.assertSummaryHasSum(handle, tag, expected_value)
|
||||
|
||||
def assertStatisticsHasScalarValue(self,
|
||||
handle,
|
||||
tag,
|
||||
expected_value,
|
||||
num_events=-1,
|
||||
offset=0):
|
||||
if tf2.enabled():
|
||||
self.assertEventHasScalarValue(handle, tag, expected_value, num_events,
|
||||
offset)
|
||||
else:
|
||||
self.assertSummaryHasScalarValue(handle, tag, expected_value)
|
||||
|
||||
def assertStatisticsHasRange(self,
|
||||
handle,
|
||||
tag,
|
||||
min_value,
|
||||
max_value,
|
||||
num_events=-1,
|
||||
offset=0):
|
||||
if tf2.enabled():
|
||||
self.assertEventHasRange(handle, tag, min_value, max_value, num_events,
|
||||
offset)
|
||||
else:
|
||||
self.assertSummaryHasRange(handle, tag, min_value, max_value)
|
||||
|
||||
def assertSummaryContains(self, summary_str, tag):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summary_str)
|
||||
for value in summary_proto.value:
|
||||
@ -40,11 +124,11 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
return
|
||||
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
|
||||
|
||||
def _assertSummaryHasCount(self,
|
||||
summary_str,
|
||||
tag,
|
||||
expected_value,
|
||||
greater_than=False):
|
||||
def assertSummaryHasCount(self,
|
||||
summary_str,
|
||||
tag,
|
||||
expected_value,
|
||||
greater_than=False):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summary_str)
|
||||
for value in summary_proto.value:
|
||||
@ -56,7 +140,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
return
|
||||
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
|
||||
|
||||
def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
|
||||
def assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summary_str)
|
||||
for value in summary_proto.value:
|
||||
@ -66,7 +150,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
return
|
||||
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
|
||||
|
||||
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
|
||||
def assertSummaryHasSum(self, summary_str, tag, expected_value):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summary_str)
|
||||
for value in summary_proto.value:
|
||||
@ -75,7 +159,7 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
return
|
||||
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
|
||||
|
||||
def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
|
||||
def assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
|
||||
summary_proto = summary_pb2.Summary()
|
||||
summary_proto.ParseFromString(summary_str)
|
||||
for value in summary_proto.value:
|
||||
@ -84,39 +168,165 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
|
||||
return
|
||||
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
|
||||
|
||||
def _testParallelCallsStats(self,
|
||||
dataset_fn,
|
||||
dataset_names,
|
||||
num_output,
|
||||
dataset_transformation,
|
||||
function_processing_time=False,
|
||||
check_elements=True):
|
||||
def assertEventContains(self, logdir, tag, num_events, offset):
|
||||
events = _events_from_logdir(logdir)
|
||||
if num_events == -1:
|
||||
self.assertGreater(len(events), 1)
|
||||
for event in events[::-1]:
|
||||
if re.match(tag, event.summary.value[0].tag):
|
||||
return
|
||||
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
|
||||
else:
|
||||
self.assertEqual(len(events), num_events)
|
||||
self.assertTrue(
|
||||
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
|
||||
|
||||
def assertEventHasCount(self, logdir, tag, count, num_events, offset):
|
||||
events = _events_from_logdir(logdir)
|
||||
if num_events == -1:
|
||||
self.assertGreater(len(events), 1)
|
||||
for event in events[::-1]:
|
||||
if re.match(tag, event.summary.value[0].tag):
|
||||
self.assertEqual(count, event.summary.value[0].histo.num)
|
||||
return
|
||||
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
|
||||
else:
|
||||
self.assertEqual(len(events), num_events)
|
||||
self.assertTrue(
|
||||
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
|
||||
self.assertEqual(
|
||||
events[num_events - offset - 1].summary.value[0].histo.num, count)
|
||||
|
||||
def assertEventHasSum(self, logdir, tag, expected_value, num_events, offset):
|
||||
events = _events_from_logdir(logdir)
|
||||
if num_events == -1:
|
||||
self.assertGreater(len(events), 1)
|
||||
for event in events[::-1]:
|
||||
if re.match(tag, event.summary.value[0].tag):
|
||||
self.assertEqual(expected_value, event.summary.value[0].histo.sum)
|
||||
return
|
||||
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
|
||||
else:
|
||||
self.assertEqual(len(events), num_events)
|
||||
self.assertTrue(
|
||||
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
|
||||
self.assertEqual(
|
||||
events[num_events - offset - 1].summary.value[0].histo.sum,
|
||||
expected_value)
|
||||
|
||||
def assertEventHasRange(self, logdir, tag, min_value, max_value, num_events,
|
||||
offset):
|
||||
events = _events_from_logdir(logdir)
|
||||
if num_events == -1:
|
||||
self.assertGreater(len(events), 1)
|
||||
for event in events[::-1]:
|
||||
if re.match(tag, event.summary.value[0].tag):
|
||||
self.assertLessEqual(min_value, event.summary.value[0].histo.min)
|
||||
self.assertGreaterEqual(max_value, event.summary.value[0].histo.max)
|
||||
return
|
||||
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
|
||||
else:
|
||||
self.assertEqual(len(events), num_events)
|
||||
self.assertTrue(
|
||||
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
|
||||
self.assertLessEqual(
|
||||
min_value, events[num_events - offset - 1].summary.value[0].histo.min)
|
||||
self.assertGreaterEqual(
|
||||
max_value, events[num_events - offset - 1].summary.value[0].histo.max)
|
||||
|
||||
def assertEventHasScalarValue(self, logdir, tag, expected_value, num_events,
|
||||
offset):
|
||||
events = _events_from_logdir(logdir)
|
||||
if num_events == -1:
|
||||
self.assertGreater(len(events), 1)
|
||||
for event in events[::-1]:
|
||||
if re.match(tag, event.summary.value[0].tag):
|
||||
self.assertEqual(expected_value, event.summary.value[0].simple_value)
|
||||
return
|
||||
self.fail("Expected tag %r not found in event file in %r" % (tag, logdir))
|
||||
else:
|
||||
self.assertEqual(len(events), num_events)
|
||||
self.assertTrue(
|
||||
re.match(tag, events[num_events - offset - 1].summary.value[0].tag))
|
||||
self.assertLessEqual(
|
||||
expected_value,
|
||||
events[num_events - offset - 1].summary.value[0].simple_value)
|
||||
|
||||
def getHandle(self, aggregator):
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(aggregator, stats_aggregator.StatsAggregatorV1):
|
||||
return self.evaluate(aggregator.get_summary())
|
||||
assert isinstance(aggregator, (stats_aggregator.StatsAggregatorV2))
|
||||
return aggregator._logdir
|
||||
|
||||
def parallelCallsStats(self,
|
||||
dataset_fn,
|
||||
dataset_names,
|
||||
num_output,
|
||||
function_processing_time=False,
|
||||
check_elements=True):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
dataset = dataset_fn()
|
||||
dataset = dataset_transformation(dataset, aggregator)
|
||||
dataset = self.datasetExperimentalStats(dataset, aggregator)
|
||||
next_element = self.getNext(dataset, requires_initialization=True)
|
||||
|
||||
for i in range(num_output):
|
||||
next_ = self.evaluate(next_element())
|
||||
if check_elements:
|
||||
self.assertAllEqual(np.array([i] * i, dtype=np.int64), next_)
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
handle = self.getHandle(aggregator)
|
||||
for dataset_name in dataset_names:
|
||||
if function_processing_time:
|
||||
self._assertSummaryHasCount(
|
||||
summary_str,
|
||||
r"(.*)::execution_time$",
|
||||
float(i + 1),
|
||||
greater_than=True)
|
||||
self._assertSummaryContains(summary_str,
|
||||
dataset_name + "thread_utilization")
|
||||
self.assertSummaryHasCount(
|
||||
handle, r"(.*)::execution_time$", float(i + 1), greater_than=True)
|
||||
self.assertSummaryContains(
|
||||
handle, self.regexForNodeName(dataset_name, "thread_utilization"))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
if function_processing_time:
|
||||
summary_str = self.evaluate(aggregator.get_summary())
|
||||
if isinstance(aggregator, stats_aggregator.StatsAggregatorV1):
|
||||
handle = self.getHandle(aggregator)
|
||||
for dataset_name in dataset_names:
|
||||
self._assertSummaryHasCount(
|
||||
summary_str,
|
||||
self.assertSummaryHasCount(
|
||||
handle,
|
||||
r"(.*)::execution_time$",
|
||||
float(num_output),
|
||||
greater_than=True)
|
||||
|
||||
|
||||
# Adding these two methods from summary_test_util, as summary_test_util is in
|
||||
# contrib.
|
||||
def _events_from_file(filepath):
|
||||
"""Returns all events in a single event file.
|
||||
|
||||
Args:
|
||||
filepath: Path to the event file.
|
||||
|
||||
Returns:
|
||||
A list of all tf.Event protos in the event file.
|
||||
"""
|
||||
records = list(tf_record.tf_record_iterator(filepath))
|
||||
result = []
|
||||
for r in records:
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(r)
|
||||
result.append(event)
|
||||
return result
|
||||
|
||||
|
||||
def _events_from_logdir(logdir):
|
||||
"""Returns all events in the single eventfile in logdir.
|
||||
|
||||
Args:
|
||||
logdir: The directory in which the single event file is sought.
|
||||
|
||||
Returns:
|
||||
A list of all tf.Event protos from the single event file.
|
||||
|
||||
Raises:
|
||||
AssertionError: If logdir does not contain exactly one file.
|
||||
"""
|
||||
assert gfile.Exists(logdir)
|
||||
files = gfile.ListDirectory(logdir)
|
||||
assert len(files) == 1, "Found not exactly one file in logdir: %s" % files
|
||||
return _events_from_file(os.path.join(logdir, files[0]))
|
||||
|
@ -17,12 +17,71 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("data.experimental.StatsAggregator")
|
||||
class StatsAggregator(object):
|
||||
_DEFAULT_MAX_QUEUE = 10
|
||||
|
||||
|
||||
@tf_export("data.experimental.StatsAggregator", v1=[])
|
||||
class StatsAggregatorV2(object):
|
||||
"""A stateful resource that aggregates statistics from one or more iterators.
|
||||
|
||||
To record statistics, use one of the custom transformation functions defined
|
||||
in this module when defining your `tf.data.Dataset`. All statistics will be
|
||||
aggregated by the `StatsAggregator` that is associated with a particular
|
||||
iterator (see below). For example, to record the latency of producing each
|
||||
element by iterating over a dataset:
|
||||
|
||||
```python
|
||||
dataset = ...
|
||||
dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
|
||||
```
|
||||
|
||||
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
|
||||
the following pattern:
|
||||
|
||||
```python
|
||||
aggregator = tf.data.experimental.StatsAggregator()
|
||||
dataset = ...
|
||||
|
||||
# Apply `StatsOptions` to associate `dataset` with `aggregator`.
|
||||
options = tf.data.Options()
|
||||
options.experimental_stats.aggregator = aggregator
|
||||
dataset = dataset.with_options(options)
|
||||
```
|
||||
|
||||
Note: This interface is experimental and expected to change. In particular,
|
||||
we expect to add other implementations of `StatsAggregator` that provide
|
||||
different ways of exporting statistics, and add more types of statistics.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._resource = ged_ops.stats_aggregator_handle_v2()
|
||||
|
||||
# There could be a conflict with multiple file writer in the same logdir,
|
||||
# (b/37351340). Possible workarounds till this bug is resolved are a) having
|
||||
# multiple dataset stats specific file inside log_dir and b) get default
|
||||
# summary writer, getting default summary writer quite doesn't solve the
|
||||
# problem as there might be summary writers in log dir not set as default
|
||||
# e.g. in Keras calback.
|
||||
# Creating a summary_writer here could potentially be replaced with getting
|
||||
# the default summary_writer if any, creating it otherwise or a public
|
||||
# method to associate summary writer.
|
||||
|
||||
self._logdir = tempfile.mkdtemp()
|
||||
self._summary_writer = summary_ops_v2.create_file_writer(
|
||||
self._logdir, max_queue=_DEFAULT_MAX_QUEUE)
|
||||
ged_ops.stats_aggregator_set_summary_writer(self._resource,
|
||||
self._summary_writer._resource) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@tf_export(v1=["data.experimental.StatsAggregator"])
|
||||
class StatsAggregatorV1(object):
|
||||
"""A stateful resource that aggregates statistics from one or more iterators.
|
||||
|
||||
To record statistics, use one of the custom transformation functions defined
|
||||
@ -70,7 +129,6 @@ class StatsAggregator(object):
|
||||
"""Creates a `StatsAggregator`."""
|
||||
self._resource = ged_ops.experimental_stats_aggregator_handle()
|
||||
|
||||
# TODO(b/116314787): Update this/add support for V2 summary API.
|
||||
def get_summary(self):
|
||||
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
|
||||
|
||||
@ -81,3 +139,8 @@ class StatsAggregator(object):
|
||||
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
|
||||
"""
|
||||
return ged_ops.experimental_stats_aggregator_summary(self._resource)
|
||||
|
||||
|
||||
# TODO(b/116314787): Change this to StatsAggregatorV2 when we have stable
|
||||
# SummaryWriterInterface, and do not break any users.
|
||||
StatsAggregator = StatsAggregatorV1
|
||||
|
@ -45,7 +45,8 @@ class StatsOptions(options.OptionsBase):
|
||||
|
||||
aggregator = options.create_option(
|
||||
name="aggregator",
|
||||
ty=stats_aggregator.StatsAggregator,
|
||||
ty=(stats_aggregator.StatsAggregatorV2,
|
||||
stats_aggregator.StatsAggregatorV1),
|
||||
docstring=
|
||||
"Associates the given statistics aggregator with the dataset pipeline.")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
path: "tensorflow.data.experimental.StatsAggregator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregatorV1\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
|
@ -3740,6 +3740,14 @@ tf_module {
|
||||
name: "StaticRegexReplace"
|
||||
argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatsAggregatorHandleV2"
|
||||
argspec: "args=[\'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatsAggregatorSetSummaryWriter"
|
||||
argspec: "args=[\'stats_aggregator\', \'summary\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StopGradient"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1,13 +1,9 @@
|
||||
path: "tensorflow.data.experimental.StatsAggregator"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
|
||||
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregatorV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_summary"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -3740,6 +3740,14 @@ tf_module {
|
||||
name: "StaticRegexReplace"
|
||||
argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatsAggregatorHandleV2"
|
||||
argspec: "args=[\'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatsAggregatorSetSummaryWriter"
|
||||
argspec: "args=[\'stats_aggregator\', \'summary\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StopGradient"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user