Set a max number of collected NodeExecStats in StepStatsCollector.

This fixes an issue where nodes in a while loop were accumulating a very
large amount of stats, exceeding the protocol buffer limit.
Change: 139277518
This commit is contained in:
Suharsh Sivakumar 2016-11-15 18:37:02 -08:00 committed by TensorFlower Gardener
parent 3d6d512fdb
commit bad3fe0aca
2 changed files with 16 additions and 1 deletions
tensorflow/core/common_runtime

View File

@ -183,7 +183,8 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
VLOG(1) << "Save dev " << device << " nt " << nt; VLOG(1) << "Save dev " << device << " nt " << nt;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!step_stats_) { if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
delete nt; delete nt;
return; return;
} }
@ -202,6 +203,7 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
dss->set_device(device); dss->set_device(device);
} }
nt->Swap(dss->add_node_stats()); nt->Swap(dss->add_node_stats());
collectedNodes++;
} }
delete nt; delete nt;
} }
@ -210,6 +212,7 @@ void StepStatsCollector::Swap(StepStats* ss) {
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK(step_stats_); CHECK(step_stats_);
ss->Swap(step_stats_); ss->Swap(step_stats_);
collectedNodes = 0;
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -27,21 +27,33 @@ class Graph;
class NodeExecStats; class NodeExecStats;
class StepStats; class StepStats;
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector { class StepStatsCollector {
public: public:
explicit StepStatsCollector(StepStats* ss); explicit StepStatsCollector(StepStats* ss);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
// device_map.
void BuildCostModel( void BuildCostModel(
CostModelManager* cost_model_manager, CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map); const std::unordered_map<string, const Graph*>& device_map);
// Save saves nt to the DeviceStats object associated with device.
void Save(const string& device, NodeExecStats* nt); void Save(const string& device, NodeExecStats* nt);
// Swap replaces the current step stats with ss.
void Swap(StepStats* ss); void Swap(StepStats* ss);
private: private:
// TODO(suharshs): Make this configurable if its not possible to find a value
// that works for all cases.
const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_; mutex mu_;
StepStats* step_stats_ GUARDED_BY(mu_); StepStats* step_stats_ GUARDED_BY(mu_);
uint64 collectedNodes GUARDED_BY(mu_) = 0;
}; };
} // namespace tensorflow } // namespace tensorflow