[tf.data] Adjusting auto-tuning period to 1 minute (from previous incorrect value of 1000 minutes) and improving auto-tuning logging.
PiperOrigin-RevId: 241826050
This commit is contained in:
parent
72222a96ac
commit
bf3bd1c026
@ -382,12 +382,11 @@ std::shared_ptr<Node> Model::AddNode(Node::Factory factory, const string& name,
|
|||||||
output_ = node;
|
output_ = node;
|
||||||
}
|
}
|
||||||
if (output) {
|
if (output) {
|
||||||
VLOG(3) << "Adding " << node->name() << "(id:" << node->id()
|
VLOG(3) << "Adding " << node->long_name() << " as input for "
|
||||||
<< ") as input for " << output->name() << "(id:" << output->id()
|
<< output->long_name();
|
||||||
<< ")";
|
|
||||||
output->add_input(node);
|
output->add_input(node);
|
||||||
} else {
|
} else {
|
||||||
VLOG(3) << "Adding " << node->name() << "(id:" << node->id() << ")";
|
VLOG(3) << "Adding " << node->long_name();
|
||||||
}
|
}
|
||||||
collect_resource_usage_ =
|
collect_resource_usage_ =
|
||||||
collect_resource_usage_ || node->has_tunable_parameters();
|
collect_resource_usage_ || node->has_tunable_parameters();
|
||||||
@ -415,16 +414,17 @@ void Model::Optimize(int64 cpu_budget) {
|
|||||||
tf_shared_lock lock(mu_);
|
tf_shared_lock lock(mu_);
|
||||||
snapshot = output_->Snapshot(nullptr);
|
snapshot = output_->Snapshot(nullptr);
|
||||||
}
|
}
|
||||||
|
VLOG(2) << "Starting optimization of tunable parameters";
|
||||||
const int64 processing_time = ProcessingTime(snapshot);
|
const int64 processing_time = ProcessingTime(snapshot);
|
||||||
auto parameters = CollectTunableParameters(snapshot);
|
auto parameters = CollectTunableParameters(snapshot);
|
||||||
for (auto& parameter : parameters) {
|
for (auto& pair : parameters) {
|
||||||
parameter->value = 1;
|
pair.second->value = 1;
|
||||||
}
|
}
|
||||||
while (true) {
|
while (true) {
|
||||||
const int64 output_time = OutputTime(snapshot);
|
const int64 output_time = OutputTime(snapshot);
|
||||||
bool all_max = true;
|
bool all_max = true;
|
||||||
for (auto& parameter : parameters) {
|
for (auto& pair : parameters) {
|
||||||
if (parameter->value < parameter->max) {
|
if (pair.second->value < pair.second->max) {
|
||||||
all_max = false;
|
all_max = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -434,17 +434,17 @@ void Model::Optimize(int64 cpu_budget) {
|
|||||||
}
|
}
|
||||||
int64 best_delta = -1;
|
int64 best_delta = -1;
|
||||||
Parameter* best_parameter = nullptr;
|
Parameter* best_parameter = nullptr;
|
||||||
for (auto& parameter : parameters) {
|
for (auto& pair : parameters) {
|
||||||
if (parameter->value == parameter->max) {
|
if (pair.second->value == pair.second->max) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
parameter->value++;
|
pair.second->value++;
|
||||||
int64 delta = output_time - OutputTime(snapshot);
|
int64 delta = output_time - OutputTime(snapshot);
|
||||||
if (delta > best_delta) {
|
if (delta > best_delta) {
|
||||||
best_delta = delta;
|
best_delta = delta;
|
||||||
best_parameter = parameter.get();
|
best_parameter = pair.second.get();
|
||||||
}
|
}
|
||||||
parameter->value--;
|
pair.second->value--;
|
||||||
}
|
}
|
||||||
if (!best_parameter) {
|
if (!best_parameter) {
|
||||||
// This should never happen because we are using a model snapshot and
|
// This should never happen because we are using a model snapshot and
|
||||||
@ -457,8 +457,10 @@ void Model::Optimize(int64 cpu_budget) {
|
|||||||
best_parameter->value++;
|
best_parameter->value++;
|
||||||
}
|
}
|
||||||
VLOG(2) << "Number of tunable parameters: " << parameters.size();
|
VLOG(2) << "Number of tunable parameters: " << parameters.size();
|
||||||
for (auto& parameter : parameters) {
|
for (auto& pair : parameters) {
|
||||||
VLOG(2) << "Setting tunable parameter: " << parameter->value;
|
auto& parameter = pair.second;
|
||||||
|
VLOG(2) << "Setting tunable parameter " << pair.first << " to "
|
||||||
|
<< parameter->value;
|
||||||
mutex_lock l(*parameter->state->mu);
|
mutex_lock l(*parameter->state->mu);
|
||||||
parameter->state->value = parameter->value;
|
parameter->state->value = parameter->value;
|
||||||
parameter->state->cond_var->notify_all();
|
parameter->state->cond_var->notify_all();
|
||||||
@ -513,15 +515,15 @@ void Model::RemoveNode(const string& name) {
|
|||||||
if ((*node)->output()) {
|
if ((*node)->output()) {
|
||||||
(*node)->output()->remove_input(*node);
|
(*node)->output()->remove_input(*node);
|
||||||
}
|
}
|
||||||
VLOG(3) << "Removing " << (*node)->name() << "(id:" << (*node)->id() << ")";
|
VLOG(3) << "Removing " << (*node)->long_name();
|
||||||
remove_node_hook_(*node);
|
remove_node_hook_(*node);
|
||||||
}
|
}
|
||||||
lookup_table_.erase(name);
|
lookup_table_.erase(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<Parameter>> Model::CollectTunableParameters(
|
std::map<string, std::shared_ptr<Parameter>> Model::CollectTunableParameters(
|
||||||
std::shared_ptr<Node> node) {
|
std::shared_ptr<Node> node) {
|
||||||
std::vector<std::shared_ptr<Parameter>> parameters;
|
std::map<string, std::shared_ptr<Parameter>> parameters;
|
||||||
node->CollectTunableParameters(¶meters);
|
node->CollectTunableParameters(¶meters);
|
||||||
return parameters;
|
return parameters;
|
||||||
}
|
}
|
||||||
|
@ -160,6 +160,9 @@ class Node {
|
|||||||
return inputs_;
|
return inputs_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a longer node name that is guaranteed to be unique.
|
||||||
|
string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); }
|
||||||
|
|
||||||
// Returns the node name.
|
// Returns the node name.
|
||||||
const string& name() const { return name_; }
|
const string& name() const { return name_; }
|
||||||
|
|
||||||
@ -212,12 +215,12 @@ class Node {
|
|||||||
|
|
||||||
// Collects tunable parameters in the subtree rooted in this node.
|
// Collects tunable parameters in the subtree rooted in this node.
|
||||||
void CollectTunableParameters(
|
void CollectTunableParameters(
|
||||||
std::vector<std::shared_ptr<Parameter>>* parameters) const
|
std::map<string, std::shared_ptr<Parameter>>* parameters) const
|
||||||
LOCKS_EXCLUDED(mu_) {
|
LOCKS_EXCLUDED(mu_) {
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
for (auto& pair : parameters_) {
|
for (auto& pair : parameters_) {
|
||||||
if (pair.second->state->tunable) {
|
if (pair.second->state->tunable) {
|
||||||
parameters->push_back(pair.second);
|
parameters->insert(std::make_pair(long_name(), pair.second));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& input : inputs_) {
|
for (auto& input : inputs_) {
|
||||||
@ -407,8 +410,9 @@ class Model {
|
|||||||
void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
|
void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Collects tunable parameters in the tree rooted in the given node.
|
// Collects tunable parameters in the tree rooted in the given node, returning
|
||||||
std::vector<std::shared_ptr<Parameter>> CollectTunableParameters(
|
// a mapping from a (unique) node name to a tunable parameter.
|
||||||
|
std::map<string, std::shared_ptr<Parameter>> CollectTunableParameters(
|
||||||
std::shared_ptr<Node> node);
|
std::shared_ptr<Node> node);
|
||||||
|
|
||||||
// Collects the output time for the given node.
|
// Collects the output time for the given node.
|
||||||
|
@ -26,7 +26,7 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
|
constexpr int64 kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis;
|
||||||
|
|
||||||
class ModelDatasetOp : public UnaryDatasetOpKernel {
|
class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -159,31 +159,32 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
|
void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
|
||||||
int64 last_optimization_ms = 0;
|
int64 last_optimization_ms = 0;
|
||||||
int64 optimization_period_ms = 10;
|
int64 optimization_period_ms = 10;
|
||||||
|
int64 current_time_ms =
|
||||||
|
ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
|
||||||
while (true) {
|
while (true) {
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
while (!cancelled_ &&
|
while (!cancelled_ &&
|
||||||
last_optimization_ms + optimization_period_ms >=
|
last_optimization_ms + optimization_period_ms >
|
||||||
ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
|
current_time_ms) {
|
||||||
cond_var_.wait_for(
|
auto wait_ms = last_optimization_ms + optimization_period_ms -
|
||||||
l, std::chrono::milliseconds(
|
current_time_ms;
|
||||||
last_optimization_ms + optimization_period_ms -
|
VLOG(2) << "Waiting for " << wait_ms << " ms.";
|
||||||
ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
|
cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
|
||||||
|
current_time_ms =
|
||||||
|
ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
|
||||||
}
|
}
|
||||||
if (cancelled_) return;
|
if (cancelled_) return;
|
||||||
}
|
}
|
||||||
model_->Optimize(dataset()->cpu_budget_);
|
model_->Optimize(dataset()->cpu_budget_);
|
||||||
// Exponentially increase the period of running the optimization
|
// Exponentially increase the period of running the optimization
|
||||||
// until a threshold is reached.
|
// until a threshold is reached.
|
||||||
if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
|
if (optimization_period_ms != kOptimizationPeriodThresholdMs) {
|
||||||
if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
|
optimization_period_ms = std::min(optimization_period_ms << 1,
|
||||||
optimization_period_ms <<= 1;
|
kOptimizationPeriodThresholdMs);
|
||||||
} else {
|
|
||||||
optimization_period_ms = kOptimizationPeriodThresholdMs;
|
|
||||||
}
|
}
|
||||||
}
|
current_time_ms = ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
|
||||||
last_optimization_ms =
|
last_optimization_ms = current_time_ms;
|
||||||
ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user