Cap the number of NodeExecStats collected by StepStatsCollector.

Save() now discards a node when kMaxCollectedNodes nodes have already been
saved since the last Swap(); Swap() resets the counter. The header also gains
descriptive comments for the class and its public methods.

diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 51abe1bdf36..318740c3e7e 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -183,7 +183,8 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
   VLOG(1) << "Save dev " << device << " nt " << nt;
   {
     mutex_lock l(mu_);
-    if (!step_stats_) {
+    if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
+      VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
       delete nt;
       return;
     }
@@ -202,6 +203,7 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
       dss->set_device(device);
     }
     nt->Swap(dss->add_node_stats());
+    ++collected_nodes_;
   }
   delete nt;
 }
@@ -210,6 +212,7 @@ void StepStatsCollector::Swap(StepStats* ss) {
   mutex_lock l(mu_);
   CHECK(step_stats_);
   ss->Swap(step_stats_);
+  collected_nodes_ = 0;
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 8b71b6a0e33..37b1c4b3086 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -27,21 +27,36 @@ class Graph;
 class NodeExecStats;
 class StepStats;
 
+// StepStatsCollector manages the collection of a StepStats object.
+// The StepStats object holds multiple DeviceStats.
+// Each DeviceStats object holds multiple NodeExecStats.
 class StepStatsCollector {
  public:
   explicit StepStatsCollector(StepStats* ss);
 
+  // BuildCostModel builds or updates a CostModel managed by cost_model_manager,
+  // using the currently collected DeviceStats associated with the devices in
+  // device_map.
   void BuildCostModel(
       CostModelManager* cost_model_manager,
       const std::unordered_map<string, const Graph*>& device_map);
 
+  // Save saves nt to the DeviceStats object associated with device.
   void Save(const string& device, NodeExecStats* nt);
 
+  // Swap replaces the current step stats with ss.
   void Swap(StepStats* ss);
 
  private:
+  // Maximum number of NodeExecStats kept between calls to Swap; once reached,
+  // Save discards further nodes.
+  // TODO(suharshs): Make this configurable if it's not possible to find a value
+  // that works for all cases.
+  static constexpr uint64 kMaxCollectedNodes = 1 << 20;
   mutex mu_;
   StepStats* step_stats_ GUARDED_BY(mu_);
+  // Number of nodes saved since the last Swap.
+  uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
 };
 
 }  // namespace tensorflow