[tf.data] Update autotuning implementation to increase optimization frequency any time the input pipeline graph changes.

PiperOrigin-RevId: 355493912
Change-Id: I52718240c46f98921f0e77d89a4409abd9f29934
Jiri Simsa 2021-02-03 15:19:40 -08:00 committed by TensorFlower Gardener
parent 6d83e8d1f1
commit 4aa50ebfa8
7 changed files with 137 additions and 98 deletions
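In brief: the per-iterator optimization loop moves out of ModelDatasetOp into model::Model, and the optimization period, which previously only grew (exponential backoff from 10 ms to 60 s), is now reset to the minimum whenever the input pipeline graph changes. A minimal standalone sketch of that scheduling policy; the OptimizationSchedule class is hypothetical, only the constants mirror the diff below:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Constants mirroring kOptimizationPeriodMinMs / kOptimizationPeriodMaxMs
// from the diff below.
constexpr int64_t kOptimizationPeriodMinMs = 10;
constexpr int64_t kOptimizationPeriodMaxMs = 60 * 1000;

class OptimizationSchedule {
 public:
  // Called after each optimization run: back off exponentially, with a cap.
  void OnOptimizationRun() {
    period_ms_ = std::min(period_ms_ << 1, kOptimizationPeriodMaxMs);
  }
  // Called when the pipeline graph changes: optimize again soon.
  void OnGraphChanged() { period_ms_ = kOptimizationPeriodMinMs; }
  int64_t period_ms() const { return period_ms_; }

 private:
  int64_t period_ms_ = kOptimizationPeriodMinMs;
};

int main() {
  OptimizationSchedule schedule;
  for (int i = 0; i < 3; ++i) schedule.OnOptimizationRun();
  std::cout << schedule.period_ms() << "\n";  // 80
  schedule.OnGraphChanged();
  std::cout << schedule.period_ms() << "\n";  // 10
}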

View File

@@ -206,4 +206,23 @@ bool CancellationManager::IsCancelling() {
return is_cancelling_;
}
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> callback,
std::function<void()>* deregister_fn) {
if (cancellation_manager) {
CancellationToken token = cancellation_manager->get_cancellation_token();
if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
return errors::Cancelled("Operation was cancelled");
}
*deregister_fn = [cancellation_manager, token]() {
cancellation_manager->DeregisterCallback(token);
};
} else {
VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
"not be registered.";
*deregister_fn = []() {};
}
return Status::OK();
}
} // end namespace tensorflow
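A hedged usage sketch of the helper above, mirroring how Model::OptimizeLoop (further down) consumes it. WaitUntilCancelled is a hypothetical caller, not part of this commit; the lock-then-notify in the callback copies the pattern the diff uses to avoid a missed wakeup:

#include <functional>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

// Hypothetical caller: block until `cancellation_manager` is cancelled.
Status WaitUntilCancelled(CancellationManager* cancellation_manager) {
  mutex mu;
  condition_variable cond_var;
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      cancellation_manager,
      // Acquire the mutex before notifying so the waiter cannot miss the
      // wakeup between its IsCancelled() check and its wait().
      [&mu, &cond_var]() {
        mutex_lock l(mu);
        cond_var.notify_all();
      },
      &deregister_fn));
  // Deregister the callback on every return path.
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
  mutex_lock l(mu);
  while (!cancellation_manager->IsCancelled()) {
    cond_var.wait(l);
  }
  return Status::OK();
}

}  // namespace tensorflow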

View File

@@ -182,6 +182,13 @@ class CancellationManager {
std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
};
// Registers the given cancellation callback, returning a function that can be
// used to deregister the callback. If `cancellation_manager` is NULL, no
// registration occurs and `deregister_fn` will be a no-op.
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> callback,
std::function<void()>* deregister_fn);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_

View File

@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include "absl/time/clock.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -1461,6 +1463,10 @@ void Model::AddNode(Node::Factory factory, const string& name,
collect_resource_usage_ =
collect_resource_usage_ || node->has_tunable_parameters();
*out_node = std::move(node);
// Reset the optimization period when a node is added so that autotuning
// adapts to changes to the input pipeline faster.
optimization_period_ms_ = kOptimizationPeriodMinMs;
cond_var_.notify_all();
}
void Model::FlushMetrics() {
@@ -1541,6 +1547,55 @@ bool Model::ShouldStop(
return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
}
// TODO(jsimsa): Add support for tracking and using the model input time.
Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget,
CancellationManager* cancellation_manager) {
std::function<void()> unused;
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
cancellation_manager,
[this]() {
mutex_lock l(mu_);
cond_var_.notify_all();
},
/*deregister_fn=*/&unused));
int64 last_optimization_ms = 0;
int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
while (true) {
{
mutex_lock l(mu_);
while (!cancellation_manager->IsCancelled() &&
last_optimization_ms + optimization_period_ms_ > current_time_ms) {
auto wait_ms =
last_optimization_ms + optimization_period_ms_ - current_time_ms;
VLOG(2) << "Waiting for " << wait_ms << " ms.";
cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
}
if (cancellation_manager->IsCancelled()) {
return Status::OK();
}
}
int64 optimization_start_us = EnvTime::NowMicros();
Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0);
VLOG(2) << "Optimized for "
<< (EnvTime::NowMicros() - optimization_start_us) << " us.";
// Exponentially increase the period of running the optimization
// until a threshold is reached.
{
mutex_lock l(mu_);
optimization_period_ms_ =
std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
}
current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
last_optimization_ms = current_time_ms;
FlushMetrics();
}
}
void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget,
double model_input_time) {
std::shared_ptr<Node> snapshot;
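The coordination between AddNode and OptimizeLoop above is the crux of the change: AddNode shrinks optimization_period_ms_ and signals cond_var_, and because the waiter recomputes its deadline on every wakeup, the shorter period takes effect immediately. A standalone sketch of that wait logic with std primitives; all names here are illustrative, not the TF API:

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Illustrative stand-in for the loop above. The deadline is recomputed on
// every wakeup, so a concurrent reset of period_ms (as Model::AddNode does)
// shortens the wait immediately instead of after the old deadline expires.
struct LoopState {
  std::mutex mu;
  std::condition_variable cond_var;
  int64_t period_ms = 10;  // reset to the minimum on graph changes
  bool cancelled = false;
};

inline int64_t NowMs() {
  return std::chrono::duration_cast<std::chrono::milliseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

// Returns false if cancelled while waiting for the next optimization.
bool WaitForNextOptimization(LoopState& state, int64_t last_optimization_ms) {
  std::unique_lock<std::mutex> l(state.mu);
  int64_t current_time_ms = NowMs();
  while (!state.cancelled &&
         last_optimization_ms + state.period_ms > current_time_ms) {
    auto wait_ms = last_optimization_ms + state.period_ms - current_time_ms;
    state.cond_var.wait_for(l, std::chrono::milliseconds(wait_ms));
    current_time_ms = NowMs();  // the deadline may have moved
  }
  return !state.cancelled;
}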

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
@@ -626,7 +628,9 @@ std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
class Model {
public:
// Creates a new model.
Model() : collect_resource_usage_(false) {}
Model()
: collect_resource_usage_(false),
optimization_period_ms_(kOptimizationPeriodMinMs) {}
// Indicates whether to collect resource usage.
bool collect_resource_usage() const { return collect_resource_usage_; }
@@ -636,33 +640,34 @@ class Model {
std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
TF_LOCKS_EXCLUDED(mu_);
// Flushes metrics record by the model.
void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
// Uses the given algorithm and resource budgets to periodically perform the
// autotuning optimization.
//
// To terminate the execution of the optimization loop, the caller needs to
// invoke `cancellation_mgr->StartCancel()`.
Status OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget, CancellationManager* cancellation_mgr);
// Uses the given algorithm to perform the autotuning optimization.
// Uses the given algorithm and resource budgets to perform the autotuning
// optimization.
void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget,
double model_input_time) TF_LOCKS_EXCLUDED(mu_);
double model_input_time);
// Removes the given node.
void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
private:
static constexpr int64 kOptimizationPeriodMinMs = 10;
static constexpr int64 kOptimizationPeriodMaxMs =
60 * EnvTime::kSecondsToMillis;
// Collects tunable parameters in the tree rooted in the given node, returning
// a mapping from a (unique) node name to a tunable parameter.
absl::flat_hash_map<string, std::shared_ptr<Parameter>>
CollectTunableParameters(std::shared_ptr<Node> node);
// Determines if we should stop the gradient descent optimization iterations
// based on number of increasable parameters, CPU budget, RAM budget and
// current resource usage.
bool ShouldStop(
int64 cpu_budget, int64 ram_budget,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
parallelism_parameters,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
buffer_size_parameters,
std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
// Flushes metrics recorded by the model.
void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
// This optimization algorithm starts by setting all tunable parallelism
// parameters to the minimum value. It then repeatedly identifies the
@@ -689,6 +694,18 @@ class Model {
double OutputTime(std::shared_ptr<Node> node, double model_input_time,
absl::flat_hash_map<string, double>* gradients);
// Determines if we should stop the gradient descent optimization iterations
// based on number of increasable parameters, CPU budget, RAM budget and
// current resource usage.
bool ShouldStop(
int64 cpu_budget, int64 ram_budget,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
parallelism_parameters,
const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
buffer_size_parameters,
std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
// Collects the processing time for the given node.
double TotalProcessingTime(std::shared_ptr<Node> node);
@@ -706,6 +723,8 @@ class Model {
// access is required only when adding or removing nodes. Concurrent access to
// existing nodes is protected by a node mutex.
mutex mu_;
// Used for coordinating the optimization loop and model modifications.
condition_variable cond_var_;
int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
@@ -716,6 +735,10 @@ class Model {
// tunable parameter (because the information is used for tuning the value of
// the parameter) and never stops.
std::atomic<bool> collect_resource_usage_;
// Determines the time the optimization loop should wait between
// running optimizations.
int64 optimization_period_ms_ TF_GUARDED_BY(mu_);
};
} // namespace model
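For scale: with the constants above, a graph change restarts the schedule at 10 ms, and the doubling reaches the 60 s cap after 13 rounds, roughly 82 s of cumulative waiting. A quick, purely illustrative check:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Mirrors the constants above: periods double from 10 ms toward 60 s.
  int64_t period_ms = 10;
  const int64_t max_ms = 60 * 1000;
  int64_t elapsed_ms = 0;
  int rounds = 0;
  while (period_ms < max_ms) {
    elapsed_ms += period_ms;
    period_ms = std::min(period_ms << 1, max_ms);
    ++rounds;
  }
  std::cout << rounds << " rounds in " << elapsed_ms << " ms\n";
  // Prints: 13 rounds in 81910 ms
}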

View File

@@ -103,26 +103,6 @@ std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
return seeds;
}
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> register_fn,
std::function<void()>* deregister_fn) {
if (cancellation_manager) {
CancellationToken token = cancellation_manager->get_cancellation_token();
if (!cancellation_manager->RegisterCallback(token,
std::move(register_fn))) {
return errors::Cancelled("Operation was cancelled");
}
*deregister_fn = [cancellation_manager, token]() {
cancellation_manager->DeregisterCallback(token);
};
} else {
VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
"not be registered.";
*deregister_fn = []() {};
}
return Status::OK();
}
Status VerifyTypeMatch(const DataType& expected, const DataType& received,
int index) {
if (expected != received) {

View File

@@ -83,12 +83,6 @@ class AnonymousResourceOp : public OpKernel {
bool create_deleter_ = true;
};
// Registers the given cancellation callback, returning a function that can be
// used to deregister the callback.
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> register_fn,
std::function<void()>* deregister_fn);
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,

View File

@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/model_dataset_op.h"
#include "tensorflow/core/framework/cancellation.h"
// On mobile we do not provide model dataset op because not all of its
// dependencies are available there. The op is replaced with a no-op.
#if !defined(IS_MOBILE_PLATFORM)
@@ -32,8 +34,6 @@ namespace tensorflow {
namespace data {
namespace {
constexpr int64 kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis;
// Default share of available RAM that can be used by model's internal buffers.
constexpr double kRamBudgetShare = 0.5;
@@ -125,16 +125,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
ram_budget_(dataset()->ram_budget_ == 0
? kRamBudgetShare * port::AvailableRam()
: dataset()->ram_budget_) {
cancellation_manager_ = absl::make_unique<CancellationManager>();
model_ = std::make_shared<model::Model>();
}
~Iterator() override {
// Signal the optimize thread to terminate it. We will then join that
// thread when we delete `this->optimize_thread_`.
mutex_lock l(mu_);
cancelled_ = true;
cond_var_.notify_all();
}
~Iterator() override { cancellation_manager_->StartCancel(); }
Status Initialize(IteratorContext* ctx) override {
IteratorContext::Params params(ctx);
@@ -149,7 +144,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
IteratorContext::Params params(ctx);
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
params.model = model_;
int64 now_nanos = EnvTime::NowNanos();
RecordInput(now_nanos);
@@ -188,56 +183,21 @@
}
private:
Status EnsureModelThreadStarted(IteratorContext* ctx)
Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!model_thread_) {
model_thread_ =
ctx->StartThread("tf_data_model", [this]() { ModelThread(); });
model_thread_ = ctx->StartThread("tf_data_model", [this]() {
Status status =
model_->OptimizeLoop(dataset()->algorithm_, cpu_budget_,
ram_budget_, cancellation_manager_.get());
if (!status.ok()) {
LOG(WARNING) << "Optimization loop failed: " << status.ToString();
}
});
}
return Status::OK();
}
void ModelThread() {
int64 last_optimization_ms = 0;
int64 optimization_period_ms = 10;
int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ && last_optimization_ms + optimization_period_ms >
current_time_ms) {
auto wait_ms =
last_optimization_ms + optimization_period_ms - current_time_ms;
VLOG(2) << "Waiting for " << wait_ms << " ms.";
cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
}
if (cancelled_) return;
}
double model_input_time;
{
tf_shared_lock l(mu_);
model_input_time = SelfInputTime();
}
int64 optimization_start_us = EnvTime::NowMicros();
model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_,
/*model_input_time=*/0);
VLOG(2) << "Optimized for "
<< (EnvTime::NowMicros() - optimization_start_us) << " us.";
// Exponentially increase the period of running the optimization
// until a threshold is reached.
if (optimization_period_ms != kOptimizationPeriodThresholdMs) {
optimization_period_ms = std::min(optimization_period_ms << 1,
kOptimizationPeriodThresholdMs);
}
current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
last_optimization_ms = current_time_ms;
model_->FlushMetrics();
}
}
void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (last_output_time_ != 0) {
DCHECK_LE(last_output_time_, time_nanos);
@@ -259,10 +219,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
}
mutex mu_;
condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
// Controls cancellation of `model_thread_`. Must be ordered before
// `model_thread_` so that `model_thread_` is destroyed first.
std::unique_ptr<CancellationManager> cancellation_manager_;
std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_;
int64 num_input_events_ TF_GUARDED_BY(mu_) = 0;
int64 input_time_ TF_GUARDED_BY(mu_) = 0;
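The ordering comment on `cancellation_manager_` above is load-bearing: C++ destroys data members in reverse declaration order, so `model_thread_` is destroyed, and its thread joined, while the cancellation manager still exists for the optimization loop to observe. A minimal illustration of the same idiom using std::jthread (C++20); the class and names are hypothetical, not TF code:

#include <atomic>
#include <chrono>
#include <thread>  // std::jthread (C++20): joins in its destructor

// Members are destroyed bottom-up: worker_ is destroyed (and joined) first,
// while stop_ -- the analogue of cancellation_manager_ -- is still alive for
// the worker thread to read. Swapping the declarations would let the thread
// touch a destroyed flag during shutdown.
class Owner {
 public:
  Owner()
      : worker_([this] {
          while (!stop_.load()) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
          }
        }) {}
  ~Owner() { stop_.store(true); }  // like cancellation_manager_->StartCancel()

 private:
  std::atomic<bool> stop_{false};  // must outlive worker_
  std::jthread worker_;            // declared last => destroyed (joined) first
};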