From bad3fe0aca8542084ab270d1ac5147b8e4a68d6f Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Tue, 15 Nov 2016 18:37:02 -0800
Subject: [PATCH] Set a max number of collected NodeExecStats in
 StepStatsCollector.

This fixes an issue where nodes in a while loop were accumulating a very
large amount of stats, exceeding the protocol buffer limit.
Change: 139277518
---
 .../core/common_runtime/step_stats_collector.cc      |  5 ++++-
 .../core/common_runtime/step_stats_collector.h       | 12 ++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 51abe1bdf36..318740c3e7e 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -183,7 +183,8 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
   VLOG(1) << "Save dev " << device << " nt " << nt;
   {
     mutex_lock l(mu_);
-    if (!step_stats_) {
+    if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+      VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
       delete nt;
       return;
     }
@@ -202,6 +203,7 @@ void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
       dss->set_device(device);
     }
     nt->Swap(dss->add_node_stats());
+    collectedNodes++;
   }
   delete nt;
 }
@@ -210,6 +212,7 @@ void StepStatsCollector::Swap(StepStats* ss) {
   mutex_lock l(mu_);
   CHECK(step_stats_);
   ss->Swap(step_stats_);
+  collectedNodes = 0;
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 8b71b6a0e33..37b1c4b3086 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -27,21 +27,33 @@ class Graph;
 class NodeExecStats;
 class StepStats;
 
+// StepStatsCollector manages the collection of a StepStats object.
+// The StepStats object holds multiple DeviceStats.
+// Each DeviceStats object holds multiple NodeExecStats.
 class StepStatsCollector {
  public:
   explicit StepStatsCollector(StepStats* ss);
 
+  // BuildCostModel builds or updates a CostModel managed by cost_model_manager,
+  // using the currently collected DeviceStats associated with the devices in
+  // device_map.
   void BuildCostModel(
       CostModelManager* cost_model_manager,
       const std::unordered_map<string, const Graph*>& device_map);
 
+  // Save saves nt to the DeviceStats object associated with device.
   void Save(const string& device, NodeExecStats* nt);
 
+  // Swap replaces the current step stats with ss.
   void Swap(StepStats* ss);
 
  private:
+  // TODO(suharshs): Make this configurable if its not possible to find a value
+  //                 that works for all cases.
+  const uint64 kMaxCollectedNodes = 1 << 20;
   mutex mu_;
   StepStats* step_stats_ GUARDED_BY(mu_);
+  uint64 collectedNodes GUARDED_BY(mu_) = 0;
 };
 
 }  // namespace tensorflow