Adds a sampler monitoring api to monitor values which are distributions.
- Not yet ready to be used. The collection part hasn't been plumbed in yet. Change: 136760667
This commit is contained in:
parent
af1c9a17a4
commit
9286bb127e
tensorflow/core
@ -185,6 +185,8 @@ cc_library(
|
||||
"lib/monitoring/counter.h",
|
||||
"lib/monitoring/metric_def.h",
|
||||
"lib/monitoring/mobile_counter.h",
|
||||
"lib/monitoring/mobile_sampler.h",
|
||||
"lib/monitoring/sampler.h",
|
||||
"lib/random/distribution_sampler.h",
|
||||
"lib/random/philox_random.h",
|
||||
"lib/random/simple_philox.h",
|
||||
@ -1517,6 +1519,7 @@ tf_cc_tests(
|
||||
"lib/monitoring/collection_registry_test.cc",
|
||||
"lib/monitoring/counter_test.cc",
|
||||
"lib/monitoring/metric_def_test.cc",
|
||||
"lib/monitoring/sampler_test.cc",
|
||||
"lib/random/distribution_sampler_test.cc",
|
||||
"lib/random/philox_random_test.cc",
|
||||
"lib/random/random_distributions_test.cc",
|
||||
|
70
tensorflow/core/lib/monitoring/mobile_sampler.h
Normal file
70
tensorflow/core/lib/monitoring/mobile_sampler.h
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Null implementation of the Sampler metric for mobile platforms.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
|
||||
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
|
||||
// SamplerCell which has a null implementation.
|
||||
class SamplerCell {
|
||||
public:
|
||||
SamplerCell() {}
|
||||
~SamplerCell() {}
|
||||
|
||||
void Add(double value) {}
|
||||
HistogramProto value() const { return HistogramProto(); }
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
|
||||
};
|
||||
|
||||
// Sampler which has a null implementation.
|
||||
template <int NumLabels>
|
||||
class Sampler {
|
||||
public:
|
||||
~Sampler() {}
|
||||
|
||||
template <typename... MetricDefArgs>
|
||||
static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
|
||||
NumLabels>& metric_def,
|
||||
const std::vector<double>& explicit_bucket_limits) {
|
||||
return new Sampler<NumLabels>();
|
||||
}
|
||||
|
||||
template <typename... Labels>
|
||||
SamplerCell* GetCell(const Labels&... labels) {
|
||||
return &default_sampler_cell_;
|
||||
}
|
||||
|
||||
private:
|
||||
Sampler() {}
|
||||
|
||||
SamplerCell default_sampler_cell_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
|
||||
};
|
||||
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
|
191
tensorflow/core/lib/monitoring/sampler.h
Normal file
191
tensorflow/core/lib/monitoring/sampler.h
Normal file
@ -0,0 +1,191 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
|
||||
|
||||
// We replace this implementation with a null implementation for mobile
|
||||
// platforms.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#ifdef IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/lib/monitoring/mobile_sampler.h"
|
||||
#else
|
||||
|
||||
#include <float.h>
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/lib/monitoring/metric_def.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
// TODO(vinuraja): Not ready yet. The collection part has to be plumbed in.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
|
||||
// SamplerCell stores each value of an Sampler.
|
||||
//
|
||||
// A cell can be passed off to a module which may repeatedly update it without
|
||||
// needing further map-indexing computations. This improves both encapsulation
|
||||
// (separate modules can own a cell each, without needing to know about the map
|
||||
// to which both cells belong) and performance (since map indexing and
|
||||
// associated locking are both avoided).
|
||||
//
|
||||
// This class is thread-safe.
|
||||
class SamplerCell {
|
||||
public:
|
||||
SamplerCell(const std::vector<double>& bucket_limits)
|
||||
: histogram_(bucket_limits) {}
|
||||
|
||||
~SamplerCell() {}
|
||||
|
||||
// Atomically adds a sample.
|
||||
void Add(double sample);
|
||||
|
||||
// Returns the current histogram value as a proto.
|
||||
HistogramProto value() const;
|
||||
|
||||
private:
|
||||
histogram::ThreadSafeHistogram histogram_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
|
||||
};
|
||||
|
||||
// A stateful class for updating a cumulative histogram metric.
|
||||
//
|
||||
// This class encapsulates a set of values (or a single value for a label-less
|
||||
// metric). Each value is identified by a tuple of labels. The class allows the
|
||||
// user to increment each value.
|
||||
//
|
||||
// Sampler allocates storage and maintains a cell for each value. You can
|
||||
// retrieve an individual cell using a label-tuple and update it separately.
|
||||
// This improves performance since operations related to retrieval, like
|
||||
// map-indexing and locking, are avoided.
|
||||
//
|
||||
// This class is thread-safe.
|
||||
template <int NumLabels>
|
||||
class Sampler {
|
||||
public:
|
||||
~Sampler() {}
|
||||
|
||||
// Creates the metric based on the metric-definition arguments.
|
||||
//
|
||||
// Example;
|
||||
// auto* sampler_with_label = Sampler<1>::New({"/tensorflow/sampler",
|
||||
// "Tensorflow sampler", "MyLabelName"}, {10.0, 20.0, 30.0});
|
||||
//
|
||||
// We automatically add -DBL_MAX and DBL_MAX to the list of bucket limits, so
|
||||
// that no sample goes out of bounds. So for the above example, the ranges end
|
||||
// up being: [-DBL_Max, 10.0, 20.0, 30.0, DBL_MAX]
|
||||
//
|
||||
// REQUIRES: bucket_limits[i] values are monotonically increasing.
|
||||
// REQUIRES: bucket_limits is not empty().
|
||||
static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
|
||||
NumLabels>& metric_def,
|
||||
const std::vector<double>& bucket_limits);
|
||||
|
||||
// Retrieves the cell for the specified labels, creating it on demand if
|
||||
// not already present.
|
||||
template <typename... Labels>
|
||||
SamplerCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
friend class SamplerCell;
|
||||
|
||||
Sampler(const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
|
||||
metric_def,
|
||||
const std::vector<double>& bucket_limits)
|
||||
: metric_def_(metric_def), bucket_limits_(bucket_limits) {}
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// The metric definition. This will be used to identify the metric when we
|
||||
// register it for collection.
|
||||
const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>
|
||||
metric_def_;
|
||||
|
||||
// Bucket limits for the histograms in the cells.
|
||||
const std::vector<double> bucket_limits_;
|
||||
|
||||
// We use a std::map here because we give out pointers to the SamplerCells,
|
||||
// which need to remain valid even after more cells.
|
||||
using LabelArray = std::array<string, NumLabels>;
|
||||
std::map<LabelArray, SamplerCell> cells_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
|
||||
};
|
||||
|
||||
////
|
||||
// Implementation details follow. API readers may skip.
|
||||
////
|
||||
|
||||
inline void SamplerCell::Add(const double sample) { histogram_.Add(sample); }
|
||||
|
||||
inline HistogramProto SamplerCell::value() const {
|
||||
HistogramProto pb;
|
||||
histogram_.EncodeToProto(&pb, true /* preserve_zero_buckets */);
|
||||
return pb;
|
||||
}
|
||||
|
||||
template <int NumLabels>
|
||||
Sampler<NumLabels>* Sampler<NumLabels>::New(
|
||||
const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
|
||||
metric_def,
|
||||
const std::vector<double>& bucket_limits) {
|
||||
CHECK_GT(bucket_limits.size(), 0);
|
||||
// Verify that the bucket boundaries are strictly increasing
|
||||
for (size_t i = 1; i < bucket_limits.size(); i++) {
|
||||
CHECK_GT(bucket_limits[i], bucket_limits[i - 1]);
|
||||
}
|
||||
std::vector<double> augmented_bucket_limits(bucket_limits);
|
||||
// We add DBL_MAX to the end so that all boundaries are within [-DBL_MAX,
|
||||
// DBL_MAX].
|
||||
if (bucket_limits.back() != DBL_MAX) {
|
||||
augmented_bucket_limits.push_back(DBL_MAX);
|
||||
}
|
||||
return new Sampler<NumLabels>(metric_def, augmented_bucket_limits);
|
||||
}
|
||||
|
||||
template <int NumLabels>
|
||||
template <typename... Labels>
|
||||
SamplerCell* Sampler<NumLabels>::GetCell(const Labels&... labels)
|
||||
LOCKS_EXCLUDED(mu_) {
|
||||
// Provides a more informative error message than the one during array
|
||||
// construction below.
|
||||
static_assert(sizeof...(Labels) == NumLabels,
|
||||
"Mismatch between Sampler<NumLabels> and number of labels "
|
||||
"provided in GetCell(...).");
|
||||
|
||||
const LabelArray& label_array = {labels...};
|
||||
mutex_lock l(mu_);
|
||||
const auto found_it = cells_.find(label_array);
|
||||
if (found_it != cells_.end()) {
|
||||
return &(found_it->second);
|
||||
}
|
||||
return &(cells_
|
||||
.emplace(std::piecewise_construct,
|
||||
std::forward_as_tuple(label_array),
|
||||
std::forward_as_tuple(bucket_limits_))
|
||||
.first->second);
|
||||
}
|
||||
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
|
92
tensorflow/core/lib/monitoring/sampler_test.cc
Normal file
92
tensorflow/core/lib/monitoring/sampler_test.cc
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
namespace {
|
||||
|
||||
using histogram::Histogram;
|
||||
|
||||
static void EqHistograms(const histogram::Histogram& expected,
|
||||
const HistogramProto& actual_proto) {
|
||||
histogram::Histogram actual;
|
||||
EXPECT_TRUE(actual.DecodeFromProto(actual_proto));
|
||||
|
||||
EXPECT_EQ(expected.ToString(), actual.ToString());
|
||||
}
|
||||
|
||||
auto* sampler_with_labels =
|
||||
Sampler<1>::New({"/tensorflow/test/sampler_with_labels",
|
||||
"Sampler with one label.", "MyLabel"},
|
||||
{10.0, 20.0});
|
||||
|
||||
TEST(LabeledSamplerTest, InitializedEmpty) {
|
||||
Histogram empty;
|
||||
EqHistograms(empty, sampler_with_labels->GetCell("Empty")->value());
|
||||
}
|
||||
|
||||
TEST(LabeledSamplerTest, BucketBoundaries) {
|
||||
// Sampler automatically adds DBL_MAX to the list of buckets.
|
||||
Histogram expected({10.0, 20.0, DBL_MAX});
|
||||
auto* cell = sampler_with_labels->GetCell("BucketBoundaries");
|
||||
sampler_with_labels->GetCell("AddedToCheckPreviousCellValidity");
|
||||
cell->Add(-1.0);
|
||||
expected.Add(-1.0);
|
||||
cell->Add(10.0);
|
||||
expected.Add(10.0);
|
||||
cell->Add(20.0);
|
||||
expected.Add(20.0);
|
||||
cell->Add(31.0);
|
||||
expected.Add(31.0);
|
||||
|
||||
EqHistograms(expected, cell->value());
|
||||
}
|
||||
|
||||
auto* init_sampler_without_labels =
|
||||
Sampler<0>::New({"/tensorflow/test/init_sampler_without_labels",
|
||||
"Sampler without labels initialized as empty."},
|
||||
{1.5, 2.8});
|
||||
|
||||
TEST(UnlabeledSamplerTest, InitializedEmpty) {
|
||||
Histogram empty;
|
||||
EqHistograms(empty, init_sampler_without_labels->GetCell()->value());
|
||||
}
|
||||
|
||||
auto* sampler_without_labels =
|
||||
Sampler<0>::New({"/tensorflow/test/sampler_without_labels",
|
||||
"Sampler without labels initialized as empty."},
|
||||
{1.5, 2.8});
|
||||
|
||||
TEST(UnlabeledSamplerTest, BucketBoundaries) {
|
||||
// Sampler automatically adds DBL_MAX to the list of buckets.
|
||||
Histogram expected({1.5, 2.8, DBL_MAX});
|
||||
auto* cell = sampler_without_labels->GetCell();
|
||||
cell->Add(-1.0);
|
||||
expected.Add(-1.0);
|
||||
cell->Add(2.0);
|
||||
expected.Add(2.0);
|
||||
cell->Add(31.0);
|
||||
expected.Add(31.0);
|
||||
|
||||
EqHistograms(expected, cell->value());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user