Set a max number of collected NodeExecStats in StepStatsCollector.
This fixes an issue where nodes in a while loop accumulated a very large number of stats, exceeding the protocol buffer limit. Change: 139277518
This commit is contained in:
parent
3d6d512fdb
commit
bad3fe0aca
tensorflow/core/common_runtime
@ -183,7 +183,8 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
|
|||||||
VLOG(1) << "Save dev " << device << " nt " << nt;
|
VLOG(1) << "Save dev " << device << " nt " << nt;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!step_stats_) {
|
if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
|
||||||
|
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
|
||||||
delete nt;
|
delete nt;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -202,6 +203,7 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
|
|||||||
dss->set_device(device);
|
dss->set_device(device);
|
||||||
}
|
}
|
||||||
nt->Swap(dss->add_node_stats());
|
nt->Swap(dss->add_node_stats());
|
||||||
|
collectedNodes++;
|
||||||
}
|
}
|
||||||
delete nt;
|
delete nt;
|
||||||
}
|
}
|
||||||
@ -210,6 +212,7 @@ void StepStatsCollector::Swap(StepStats* ss) {
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
CHECK(step_stats_);
|
CHECK(step_stats_);
|
||||||
ss->Swap(step_stats_);
|
ss->Swap(step_stats_);
|
||||||
|
collectedNodes = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -27,21 +27,33 @@ class Graph;
|
|||||||
class NodeExecStats;
|
class NodeExecStats;
|
||||||
class StepStats;
|
class StepStats;
|
||||||
|
|
||||||
|
// StepStatsCollector manages the collection of a StepStats object.
|
||||||
|
// The StepStats object holds multiple DeviceStats.
|
||||||
|
// Each DeviceStats object holds multiple NodeExecStats.
|
||||||
class StepStatsCollector {
|
class StepStatsCollector {
|
||||||
public:
|
public:
|
||||||
explicit StepStatsCollector(StepStats* ss);
|
explicit StepStatsCollector(StepStats* ss);
|
||||||
|
|
||||||
|
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
|
||||||
|
// using the currently collected DeviceStats associated with the devices in
|
||||||
|
// device_map.
|
||||||
void BuildCostModel(
|
void BuildCostModel(
|
||||||
CostModelManager* cost_model_manager,
|
CostModelManager* cost_model_manager,
|
||||||
const std::unordered_map<string, const Graph*>& device_map);
|
const std::unordered_map<string, const Graph*>& device_map);
|
||||||
|
|
||||||
|
// Save saves nt to the DeviceStats object associated with device.
|
||||||
void Save(const string& device, NodeExecStats* nt);
|
void Save(const string& device, NodeExecStats* nt);
|
||||||
|
|
||||||
|
// Swap replaces the current step stats with ss.
|
||||||
void Swap(StepStats* ss);
|
void Swap(StepStats* ss);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// TODO(suharshs): Make this configurable if it's not possible to find a value
|
||||||
|
// that works for all cases.
|
||||||
|
const uint64 kMaxCollectedNodes = 1 << 20;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
StepStats* step_stats_ GUARDED_BY(mu_);
|
StepStats* step_stats_ GUARDED_BY(mu_);
|
||||||
|
uint64 collectedNodes GUARDED_BY(mu_) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user