[tf.data] Update autotuning implementation to increase optimization frequency any time the input pipeline graph changes.
PiperOrigin-RevId: 355493912 Change-Id: I52718240c46f98921f0e77d89a4409abd9f29934
parent 6d83e8d1f1
commit 4aa50ebfa8
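In brief: the autotuning optimization period still backs off exponentially from 10 ms to a 60 s cap, but it is now reset to the minimum whenever a node is added to the model, so autotuning reacts quickly when the input pipeline graph changes. A standalone sketch of that schedule (illustrative names, not the TensorFlow API; the constants mirror `kOptimizationPeriodMinMs`/`kOptimizationPeriodMaxMs` introduced below):

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of the optimization-period schedule this change introduces.
class PeriodSchedule {
 public:
  // After each optimization round: double the period, up to the cap.
  void OnOptimizationDone() { period_ms_ = std::min(period_ms_ << 1, kMaxMs); }
  // When the pipeline graph changes (a node is added): reset to the minimum
  // so the next optimization runs soon.
  void OnGraphChanged() { period_ms_ = kMinMs; }
  int64_t period_ms() const { return period_ms_; }

 private:
  static constexpr int64_t kMinMs = 10;         // kOptimizationPeriodMinMs
  static constexpr int64_t kMaxMs = 60 * 1000;  // kOptimizationPeriodMaxMs
  int64_t period_ms_ = kMinMs;
};
```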
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -206,4 +206,23 @@ bool CancellationManager::IsCancelling() {
   return is_cancelling_;
 }
 
+Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
+                                    std::function<void()> callback,
+                                    std::function<void()>* deregister_fn) {
+  if (cancellation_manager) {
+    CancellationToken token = cancellation_manager->get_cancellation_token();
+    if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
+      return errors::Cancelled("Operation was cancelled");
+    }
+    *deregister_fn = [cancellation_manager, token]() {
+      cancellation_manager->DeregisterCallback(token);
+    };
+  } else {
+    VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
+               "not be registered.";
+    *deregister_fn = []() {};
+  }
+  return Status::OK();
+}
+
 }  // end namespace tensorflow
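For reference, a minimal sketch of how a caller uses the new helper (`WaitUntilCancelled` is a hypothetical function, not part of this change; it assumes TensorFlow's internal `mutex`/`condition_variable` wrappers and a non-null manager — the real call site is `Model::OptimizeLoop` further down):

```cpp
#include <functional>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

// Hypothetical caller: sleep until cancellation is requested, using the
// registered callback to wake the waiting thread.
Status WaitUntilCancelled(CancellationManager* cm, mutex* mu,
                          condition_variable* cv) {
  std::function<void()> deregister_fn;
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
      cm,
      [mu, cv]() {
        mutex_lock l(*mu);
        cv->notify_all();
      },
      &deregister_fn));
  {
    mutex_lock l(*mu);
    while (!cm->IsCancelled()) {
      cv->wait(l);
    }
  }
  // Remove the callback so it cannot outlive this function's captures.
  deregister_fn();
  return Status::OK();
}

}  // namespace tensorflow
```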
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -182,6 +182,13 @@ class CancellationManager {
   std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
 };
 
+// Registers the given cancellation callback, returning a function that can be
+// used to deregister the callback. If `cancellation_manager` is NULL, no
+// registration occurs and `deregister_fn` will be a no-op.
+Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
+                                    std::function<void()> callback,
+                                    std::function<void()>* deregister_fn);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -18,6 +18,8 @@ limitations under the License.
 #include <memory>
 
 #include "absl/time/clock.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
@@ -1461,6 +1463,10 @@ void Model::AddNode(Node::Factory factory, const string& name,
   collect_resource_usage_ =
       collect_resource_usage_ || node->has_tunable_parameters();
   *out_node = std::move(node);
+  // Reset the optimization period when a node is added so that autotuning
+  // adapts to changes to the input pipeline faster.
+  optimization_period_ms_ = kOptimizationPeriodMinMs;
+  cond_var_.notify_all();
 }
 
 void Model::FlushMetrics() {
@@ -1541,6 +1547,55 @@ bool Model::ShouldStop(
   return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
 }
 
+// TODO(jsimsa): Add support for tracking and using the model input time.
+Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
+                           int64 ram_budget,
+                           CancellationManager* cancellation_manager) {
+  std::function<void()> unused;
+  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
+      cancellation_manager,
+      [this]() {
+        mutex_lock l(mu_);
+        cond_var_.notify_all();
+      },
+      /*deregister_fn=*/&unused));
+
+  int64 last_optimization_ms = 0;
+  int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
+  while (true) {
+    {
+      mutex_lock l(mu_);
+      while (!cancellation_manager->IsCancelled() &&
+             last_optimization_ms + optimization_period_ms_ > current_time_ms) {
+        auto wait_ms =
+            last_optimization_ms + optimization_period_ms_ - current_time_ms;
+        VLOG(2) << "Waiting for " << wait_ms << " ms.";
+        cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
+        current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
+      }
+      if (cancellation_manager->IsCancelled()) {
+        return Status::OK();
+      }
+    }
+
+    int64 optimization_start_us = EnvTime::NowMicros();
+    Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0);
+    VLOG(2) << "Optimized for "
+            << (EnvTime::NowMicros() - optimization_start_us) << " us.";
+
+    // Exponentially increase the period of running the optimization
+    // until a threshold is reached.
+    {
+      mutex_lock l(mu_);
+      optimization_period_ms_ =
+          std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
+    }
+    current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
+    last_optimization_ms = current_time_ms;
+    FlushMetrics();
+  }
+}
+
 void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget,
                                     double model_input_time) {
   std::shared_ptr<Node> snapshot;
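The scheduling arithmetic in `OptimizeLoop` is simple enough to state standalone; a sketch (plain C++ with illustrative names, not the TensorFlow types):

```cpp
#include <algorithm>
#include <cstdint>

// Restatement of the wait computation in Model::OptimizeLoop: sleep until
// `optimization_period_ms` has elapsed since the last optimization; a
// non-positive result means "optimize now".
int64_t WaitMs(int64_t last_optimization_ms, int64_t optimization_period_ms,
               int64_t current_time_ms) {
  return std::max<int64_t>(
      0, last_optimization_ms + optimization_period_ms - current_time_ms);
}
```

In the real loop the surrounding `while` condition guarantees the computed wait is positive, and a `cond_var_` notification (from `AddNode` or the cancellation callback) cuts the wait short.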
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -32,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
 
 namespace tensorflow {
 namespace data {
@@ -626,7 +628,9 @@ std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
 class Model {
  public:
   // Creates a new model.
-  Model() : collect_resource_usage_(false) {}
+  Model()
+      : collect_resource_usage_(false),
+        optimization_period_ms_(kOptimizationPeriodMinMs) {}
 
   // Indicates whether to collect resource usage.
   bool collect_resource_usage() const { return collect_resource_usage_; }
@@ -636,33 +640,34 @@ class Model {
                std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
       TF_LOCKS_EXCLUDED(mu_);
 
-  // Flushes metrics record by the model.
-  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
+  // Uses the given algorithm and resource budgets to periodically perform the
+  // autotuning optimization.
+  //
+  // To terminate the execution of the optimization loop, the caller needs to
+  // invoke `cancellation_mgr->StartCancel()`.
+  Status OptimizeLoop(AutotuneAlgorithm algorithm, int64 cpu_budget,
+                      int64 ram_budget, CancellationManager* cancellation_mgr);
 
-  // Uses the given algorithm to perform the autotuning optimization.
+  // Uses the given algorithm and resource budgets to perform the autotuning
+  // optimization.
   void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget,
-                double model_input_time) TF_LOCKS_EXCLUDED(mu_);
+                double model_input_time);
 
   // Removes the given node.
   void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
 
  private:
+  static constexpr int64 kOptimizationPeriodMinMs = 10;
+  static constexpr int64 kOptimizationPeriodMaxMs =
+      60 * EnvTime::kSecondsToMillis;
+
   // Collects tunable parameters in the tree rooted in the given node, returning
   // a mapping from a (unique) node name to a tunable parameter.
   absl::flat_hash_map<string, std::shared_ptr<Parameter>>
   CollectTunableParameters(std::shared_ptr<Node> node);
 
-  // Determines if we should stop the gradient descent optimization iterations
-  // based on number of increasable parameters, CPU budget, RAM budget and
-  // current resource usage.
-  bool ShouldStop(
-      int64 cpu_budget, int64 ram_budget,
-      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
-      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
-          parallelism_parameters,
-      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
-          buffer_size_parameters,
-      std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
+  // Flushes metrics recorded by the model.
+  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
 
   // This optimization algorithm starts by setting all tunable parallelism
   // parameters to the minimum value. It then repeatedly identifies the
@@ -689,6 +694,18 @@ class Model {
   double OutputTime(std::shared_ptr<Node> node, double model_input_time,
                     absl::flat_hash_map<string, double>* gradients);
 
+  // Determines if we should stop the gradient descent optimization iterations
+  // based on number of increasable parameters, CPU budget, RAM budget and
+  // current resource usage.
+  bool ShouldStop(
+      int64 cpu_budget, int64 ram_budget,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>& parameters,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          parallelism_parameters,
+      const absl::flat_hash_map<string, std::shared_ptr<Parameter>>&
+          buffer_size_parameters,
+      std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
+
   // Collects the processing time for the given node.
   double TotalProcessingTime(std::shared_ptr<Node> node);
 
@@ -706,6 +723,8 @@ class Model {
   // access is required only when adding or removing nodes. Concurrent access to
   // existing nodes is protected by a node mutex.
   mutex mu_;
+  // Used for coordinating the optimization loop and model modifications.
+  condition_variable cond_var_;
   int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
   std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
 
@@ -716,6 +735,10 @@ class Model {
   // tunable parameter (because the information is used for tuning the value of
   // the parameter) and never stops.
   std::atomic<bool> collect_resource_usage_;
+
+  // Determines the time the optimization loop should wait between
+  // running optimizations.
+  int64 optimization_period_ms_ TF_GUARDED_BY(mu_);
 };
 
 }  // namespace model
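The new `cond_var_` pairs with the existing `mu_`: `AddNode` resets `optimization_period_ms_` under the lock and notifies, so an `OptimizeLoop` sleeping out a long period wakes up immediately. A miniature version of that handshake (standard C++ here rather than the TensorFlow wrappers; names are illustrative):

```cpp
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Miniature AddNode/OptimizeLoop handshake: the mutator resets the period
// and notifies; the sleeping loop wakes early instead of waiting out a
// period that may be up to a minute long.
struct PeriodCoordinator {
  std::mutex mu;
  std::condition_variable cv;
  int64_t period_ms = 10;  // cf. kOptimizationPeriodMinMs

  void OnNodeAdded() {
    std::lock_guard<std::mutex> l(mu);
    period_ms = 10;   // Reset to the minimum period.
    cv.notify_all();  // Wake the optimization loop immediately.
  }

  void WaitOnePeriod() {
    std::unique_lock<std::mutex> l(mu);
    cv.wait_for(l, std::chrono::milliseconds(period_ms));
  }
};
```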
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -103,26 +103,6 @@ std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
   return seeds;
 }
 
-Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
-                                    std::function<void()> register_fn,
-                                    std::function<void()>* deregister_fn) {
-  if (cancellation_manager) {
-    CancellationToken token = cancellation_manager->get_cancellation_token();
-    if (!cancellation_manager->RegisterCallback(token,
-                                                std::move(register_fn))) {
-      return errors::Cancelled("Operation was cancelled");
-    }
-    *deregister_fn = [cancellation_manager, token]() {
-      cancellation_manager->DeregisterCallback(token);
-    };
-  } else {
-    VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
-               "not be registered.";
-    *deregister_fn = []() {};
-  }
-  return Status::OK();
-}
-
 Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                        int index) {
   if (expected != received) {
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -83,12 +83,6 @@ class AnonymousResourceOp : public OpKernel {
   bool create_deleter_ = true;
 };
 
-// Registers the given cancellation callback, returning a function that can be
-// used to deregister the callback.
-Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
-                                    std::function<void()> register_fn,
-                                    std::function<void()>* deregister_fn);
-
 // Returns Status::OK() if `expected` and `received` types match,
 // errors::InvalidArgument otherwise.
 Status VerifyTypesMatch(const DataTypeVector& expected,
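Note: taken together with the first two hunks, these removals are a move. `RegisterCancellationCallback` leaves `dataset_utils.{h,cc}` for `cancellation.{h,cc}`, its misleading `register_fn` parameter is renamed to `callback`, and the null-manager no-op behavior is now documented at the declaration.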
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/model_dataset_op.h"
 
+#include "tensorflow/core/framework/cancellation.h"
+
 // On mobile we do not provide model dataset op because not all of its
 // dependencies are available there. The op is replaced with a no-op.
 #if !defined(IS_MOBILE_PLATFORM)
@@ -32,8 +34,6 @@ namespace tensorflow {
 namespace data {
 namespace {
 
-constexpr int64 kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis;
-
 // Default share of available RAM that can be used by model's internal buffers.
 constexpr double kRamBudgetShare = 0.5;
 
@@ -125,16 +125,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
           ram_budget_(dataset()->ram_budget_ == 0
                           ? kRamBudgetShare * port::AvailableRam()
                           : dataset()->ram_budget_) {
+      cancellation_manager_ = absl::make_unique<CancellationManager>();
       model_ = std::make_shared<model::Model>();
     }
 
-    ~Iterator() override {
-      // Signal the optimize thread to terminate it. We will then join that
-      // thread when we delete `this->optimize_thread_`.
-      mutex_lock l(mu_);
-      cancelled_ = true;
-      cond_var_.notify_all();
-    }
+    ~Iterator() override { cancellation_manager_->StartCancel(); }
 
     Status Initialize(IteratorContext* ctx) override {
       IteratorContext::Params params(ctx);
@@ -149,7 +144,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
       IteratorContext::Params params(ctx);
       {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
+        TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
         params.model = model_;
         int64 now_nanos = EnvTime::NowNanos();
         RecordInput(now_nanos);
@@ -188,56 +183,21 @@ class ModelDatasetOp::Dataset : public DatasetBase {
     }
 
    private:
-    Status EnsureModelThreadStarted(IteratorContext* ctx)
+    Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (!model_thread_) {
-        model_thread_ =
-            ctx->StartThread("tf_data_model", [this]() { ModelThread(); });
+        model_thread_ = ctx->StartThread("tf_data_model", [this]() {
+          Status status =
+              model_->OptimizeLoop(dataset()->algorithm_, cpu_budget_,
+                                   ram_budget_, cancellation_manager_.get());
+          if (!status.ok()) {
+            LOG(WARNING) << "Optimization loop failed: " << status.ToString();
+          }
+        });
       }
       return Status::OK();
    }
 
-    void ModelThread() {
-      int64 last_optimization_ms = 0;
-      int64 optimization_period_ms = 10;
-      int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
-      while (true) {
-        {
-          mutex_lock l(mu_);
-          while (!cancelled_ && last_optimization_ms + optimization_period_ms >
-                                    current_time_ms) {
-            auto wait_ms =
-                last_optimization_ms + optimization_period_ms - current_time_ms;
-            VLOG(2) << "Waiting for " << wait_ms << " ms.";
-            cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
-            current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
-          }
-          if (cancelled_) return;
-        }
-        double model_input_time;
-        {
-          tf_shared_lock l(mu_);
-          model_input_time = SelfInputTime();
-        }
-
-        int64 optimization_start_us = EnvTime::NowMicros();
-        model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_,
-                         /*model_input_time=*/0);
-        VLOG(2) << "Optimized for "
-                << (EnvTime::NowMicros() - optimization_start_us) << " us.";
-
-        // Exponentially increase the period of running the optimization
-        // until a threshold is reached.
-        if (optimization_period_ms != kOptimizationPeriodThresholdMs) {
-          optimization_period_ms = std::min(optimization_period_ms << 1,
-                                            kOptimizationPeriodThresholdMs);
-        }
-        current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
-        last_optimization_ms = current_time_ms;
-        model_->FlushMetrics();
-      }
-    }
-
     void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (last_output_time_ != 0) {
         DCHECK_LE(last_output_time_, time_nanos);
@@ -259,10 +219,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
     }
 
     mutex mu_;
-    condition_variable cond_var_;
     std::shared_ptr<model::Model> model_;
+    // Controls cancellation of `model_thread_`. Must be ordered before
+    // `model_thread_` so that `model_thread_` is destroyed first.
+    std::unique_ptr<CancellationManager> cancellation_manager_;
     std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
-    bool cancelled_ TF_GUARDED_BY(mu_) = false;
     std::unique_ptr<IteratorBase> input_impl_;
     int64 num_input_events_ TF_GUARDED_BY(mu_) = 0;
     int64 input_time_ TF_GUARDED_BY(mu_) = 0;
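One detail worth calling out: the comment on `cancellation_manager_` relies on C++ destroying members in reverse declaration order, and on `Thread`'s destructor joining the thread. A miniature illustration (standard C++, hypothetical names; `JoiningThread` stands in for `tensorflow::Thread`):

```cpp
#include <atomic>
#include <functional>
#include <memory>
#include <thread>

// Stand-in for tensorflow::Thread, whose destructor joins the thread.
class JoiningThread {
 public:
  explicit JoiningThread(std::function<void()> fn) : t_(std::move(fn)) {}
  ~JoiningThread() { t_.join(); }

 private:
  std::thread t_;
};

// Members are destroyed in reverse declaration order, so `worker_`
// (declared last) is destroyed, and therefore joined, first, while
// `stop_` (the stand-in for `cancellation_manager_`) is still alive
// for the worker to observe.
class Owner {
 public:
  Owner()
      : worker_(std::make_unique<JoiningThread>([this]() {
          while (!stop_.load()) {
            // Periodic work, cf. Model::OptimizeLoop.
          }
        })) {}
  ~Owner() { stop_.store(true); }  // Analogous to StartCancel().

 private:
  std::atomic<bool> stop_{false};          // Destroyed last.
  std::unique_ptr<JoiningThread> worker_;  // Destroyed first (joins).
};
```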