Set a max number of collected NodeExecStats in StepStatsCollector.

This fixes an issue where nodes in a while loop were accumulating a very
large amount of stats, exceeding the protocol buffer limit.
Change: 139277518
This commit is contained in:
Suharsh Sivakumar 2016-11-15 18:37:02 -08:00 committed by TensorFlower Gardener
parent 3d6d512fdb
commit bad3fe0aca
2 changed files with 16 additions and 1 deletions

View File

@ -183,7 +183,8 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
VLOG(1) << "Save dev " << device << " nt " << nt;
{
mutex_lock l(mu_);
if (!step_stats_) {
if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
delete nt;
return;
}
@ -202,6 +203,7 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
dss->set_device(device);
}
nt->Swap(dss->add_node_stats());
collectedNodes++;
}
delete nt;
}
@ -210,6 +212,7 @@ void StepStatsCollector::Swap(StepStats* ss) {
mutex_lock l(mu_);
CHECK(step_stats_);
ss->Swap(step_stats_);
collectedNodes = 0;
}
} // namespace tensorflow

View File

@ -27,21 +27,33 @@ class Graph;
class NodeExecStats;
class StepStats;
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector {
public:
explicit StepStatsCollector(StepStats* ss);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
// device_map.
void BuildCostModel(
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
// Save saves nt to the DeviceStats object associated with device.
void Save(const string& device, NodeExecStats* nt);
// Swap replaces the current step stats with ss.
void Swap(StepStats* ss);
private:
// TODO(suharshs): Make this configurable if its not possible to find a value
// that works for all cases.
const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
StepStats* step_stats_ GUARDED_BY(mu_);
uint64 collectedNodes GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow