From 20d3933a57c44d73bb85b51fc6e6a3c58201fd7f Mon Sep 17 00:00:00 2001
From: Xinan Jiang <xinan.jxn@gmail.com>
Date: Thu, 12 Nov 2020 15:29:13 +0800
Subject: [PATCH 1/4] [CostModel] Generating CostGraphDef with Global Index

---
 tensorflow/core/graph/costmodel.cc | 7 ++++---
 tensorflow/core/graph/costmodel.h  | 8 ++++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index 1df45d9b893..57a9c1a6dd5 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -479,11 +479,12 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
                                   CostGraphDef* cost_graph) const {
   std::vector<const Edge*> inputs;
   std::vector<const Edge*> control_inputs;
+  int offset = cost_graph->node_size();
   for (const Node* n : graph->nodes()) {
     CostGraphDef::Node* cnode = cost_graph->add_node();
     cnode->set_name(n->name());
     cnode->set_device(n->assigned_device_name());
-    cnode->set_id(Id(n));
+    cnode->set_id(GlobalId(n, offset));
 
     inputs.clear();
     inputs.resize(n->num_inputs(), nullptr);
@@ -502,7 +503,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
 
     for (const Edge* e : inputs) {
       CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
-      input_info->set_preceding_node(Id(e->src()));
+      input_info->set_preceding_node(GlobalId(e->src(), offset));
       input_info->set_preceding_port(e->src_output());
     }
 
@@ -528,7 +529,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
     }
 
     for (const Edge* e : control_inputs) {
-      cnode->add_control_input(Id(e->src()));
+      cnode->add_control_input(GlobalId(e->src(), offset));
     }
 
     cnode->set_temporary_memory_size(TempMemorySize(n).value());
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index 2d94dd5cdc8..31568d7c889 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -66,6 +66,14 @@ class CostModel {
     }
   }
 
+  // Returns the id to record for 'n' in a CostGraphDef: n->cost_id() when
+  // this model is global, otherwise n->id() shifted by 'offset' (which
+  // AddToCostGraphDef sets to the CostGraphDef's node count on entry).
+  inline int GlobalId(const Node* n, int offset) const {
+    if (is_global_) return n->cost_id();
+    return n->id() + offset;
+  }
+
   // Initializes cost model for 'g'.
   void InitFromGraph(const Graph& g);
 

From 24ad11d8354475ef90ac725a9aafd73ef406663f Mon Sep 17 00:00:00 2001
From: Xinan Jiang <xinan.jxn@gmail.com>
Date: Fri, 8 Jan 2021 17:55:39 +0800
Subject: [PATCH 2/4] [CostModel] Add UT for CostModel Global Index

---
 tensorflow/core/BUILD                   |  1 +
 tensorflow/core/graph/BUILD             |  1 +
 tensorflow/core/graph/costmodel.cc      |  1 +
 tensorflow/core/graph/costmodel_test.cc | 98 +++++++++++++++++++++++++
 tensorflow/core/util/dump_graph.cc      | 21 ++++++
 tensorflow/core/util/dump_graph.h       |  5 ++
 6 files changed, 127 insertions(+)
 create mode 100644 tensorflow/core/graph/costmodel_test.cc

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9e06a07dd3a..8f34f0bce60 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1877,6 +1877,7 @@ tf_cc_tests(
         "//tensorflow/core/example:feature_util_test.cc",
         "//tensorflow/core/graph:algorithm_test.cc",
         "//tensorflow/core/graph:control_flow_test.cc",
+        "//tensorflow/core/graph:costmodel_test.cc",
         "//tensorflow/core/graph:edgeset_test.cc",
         "//tensorflow/core/graph:graph_def_builder_test.cc",
         "//tensorflow/core/graph:graph_partition_test.cc",
diff --git a/tensorflow/core/graph/BUILD b/tensorflow/core/graph/BUILD
index e3b5076b9f6..ea95e6cce71 100644
--- a/tensorflow/core/graph/BUILD
+++ b/tensorflow/core/graph/BUILD
@@ -201,6 +201,7 @@ exports_files(
     srcs = [
         "algorithm_test.cc",
         "control_flow_test.cc",
+        "costmodel_test.cc",
         "edgeset_test.cc",
         "graph_def_builder_test.cc",
         "graph_partition_test.cc",
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index 57a9c1a6dd5..d2c68ba0353 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/graph/costmodel.h"
 
+#include <algorithm>
 #include <vector>
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/cost_graph.pb.h"
diff --git a/tensorflow/core/graph/costmodel_test.cc b/tensorflow/core/graph/costmodel_test.cc
new file mode 100644
index 00000000000..b397b5953c7
--- /dev/null
+++ b/tensorflow/core/graph/costmodel_test.cc
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/graph/costmodel.h"
+
+#include <memory>
+#include <string>
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/framework/cost_graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/dump_graph.h"
+
+namespace tensorflow {
+namespace {
+
+// Builds 'graph' from the text-format GraphDef in 's'. CHECK-fails (logging
+// 's') if the text does not parse, or if graph construction fails.
+static void InitGraph(const string& s, Graph* graph) {
+  GraphDef parsed;
+  protobuf::TextFormat::Parser text_parser;
+  CHECK(text_parser.MergeFromString(s, &parsed)) << s;
+  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(), parsed, graph));
+}
+
+// Appends one DeviceStepStats entry for 'device_name' to 'step_stats',
+// holding a NodeExecStats (node name only) for every node in 'graph'.
+static void GenerateStepStats(
+    Graph* graph, StepStats* step_stats, const string& device_name) {
+  DeviceStepStats* dev_stats = step_stats->add_dev_stats();
+  dev_stats->set_device(device_name);
+  for (const Node* node : graph->nodes()) {
+    dev_stats->add_node_stats()->set_node_name(node->name());
+  }
+}
+
+REGISTER_OP("Input").Output("o: float").SetIsStateful();
+
+TEST(CostModelTest, GlobalId) {
+  // Two 4-node graphs are appended to one CostGraphDef. With each Graph's
+  // implicit source and sink nodes that is 12 nodes, all needing distinct ids.
+  std::unique_ptr<Graph> graph1(new Graph(OpRegistry::Global()));
+  std::unique_ptr<Graph> graph2(new Graph(OpRegistry::Global()));
+  InitGraph(
+      "node { name: 'A1' op: 'Input'}"
+      "node { name: 'B1' op: 'Input'}"
+      "node { name: 'C1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A1', 'B1'] }"
+      "node { name: 'D1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A1', 'B1'] }", graph1.get());
+  InitGraph(
+      "node { name: 'A2' op: 'Input'}"
+      "node { name: 'B2' op: 'Input'}"
+      "node { name: 'C2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A2', 'B2'] }"
+      "node { name: 'D2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A2', 'B2'] }", graph2.get());
+  // Record a stats entry for every node, under one dummy device per graph.
+  StepStats step_stats;
+  GenerateStepStats(graph1.get(), &step_stats, "DummyDevice1");
+  GenerateStepStats(graph2.get(), &step_stats, "DummyDevice2");
+  StepStatsCollector collector(&step_stats);
+  std::unordered_map<string, const Graph*> device_map;
+  device_map["DummyDevice1"] = graph1.get();
+  device_map["DummyDevice2"] = graph2.get();
+  CostModelManager cost_model_manager;
+  collector.BuildCostModel(&cost_model_manager, device_map);
+  CostGraphDef cost_graph_def;
+  cost_model_manager.AddToCostGraphDef(graph1.get(), &cost_graph_def);
+  cost_model_manager.AddToCostGraphDef(graph2.get(), &cost_graph_def);
+  ASSERT_EQ(cost_graph_def.node_size(), 12);
+  absl::flat_hash_map<int32, const CostGraphDef::Node> ids;
+  for (auto node : cost_graph_def.node()) {
+    // insert().second is false when this id was already seen.
+    const bool is_new_id = ids.insert({node.id(), node}).second;
+    EXPECT_TRUE(is_new_id) << "duplicate CostGraphDef node id " << node.id();
+  }
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc
index 0e16f9d3fb3..ebc1b6db137 100644
--- a/tensorflow/core/util/dump_graph.cc
+++ b/tensorflow/core/util/dump_graph.cc
@@ -18,6 +18,8 @@ limitations under the License.
 
 #include "tensorflow/core/util/dump_graph.h"
 
+#include <memory>
+#include <unordered_map>
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
@@ -209,6 +211,25 @@ string DumpGraphDefToFile(const string& name, GraphDef const& graph_def,
   return filepath;
 }
 
+// Dumps 'graph_def' as a text proto; returns the file path or an error note.
+string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def,
+                              const string& dirname) {
+  string path;
+  std::unique_ptr<WritableFile> out;
+  Status st = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt",
+                                 &path, &out);
+  if (!st.ok()) {
+    return StrCat("(failed to create writable file: ", st.ToString(), ")");
+  }
+  st = WriteTextProtoToUniqueFile(graph_def, out.get());
+  if (!st.ok()) {
+    return StrCat("(failed to dump Graph to '", path, "': ", st.ToString(),
+                  ")");
+  }
+  LOG(INFO) << "Dumped Graph to " << path;
+  return path;
+}
+
 string DumpGraphToFile(const string& name, Graph const& graph,
                        const FunctionLibraryDefinition* flib_def,
                        const string& dirname) {
diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h
index e4428bd0206..f59e2ef5753 100644
--- a/tensorflow/core/util/dump_graph.h
+++ b/tensorflow/core/util/dump_graph.h
@@ -19,6 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
 #define TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
 
+#include <string>
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/graph/graph.h"
@@ -39,6 +40,10 @@ namespace tensorflow {
 string DumpGraphDefToFile(const string& name, GraphDef const& graph_def,
                           const string& dirname = "");
 
+// Like DumpGraphDefToFile, but dumps a CostGraphDef instead of a GraphDef.
+string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def,
+                          const string& dirname = "");
+
 // Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph'
 // and an optional function library 'flib_def'. Returns the file name chosen.
 string DumpGraphToFile(const string& name, Graph const& graph,

From f71df7a4349226ee358b22bb4438554697ebb3e8 Mon Sep 17 00:00:00 2001
From: Xinan Jiang <xinan.jxn@gmail.com>
Date: Tue, 26 Jan 2021 14:47:44 +0800
Subject: [PATCH 3/4] Add missing include file

---
 tensorflow/core/util/dump_graph.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h
index f59e2ef5753..c9d6a4f9aab 100644
--- a/tensorflow/core/util/dump_graph.h
+++ b/tensorflow/core/util/dump_graph.h
@@ -20,6 +20,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
 
 #include <string>
+#include "tensorflow/core/framework/cost_graph.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/graph/graph.h"

From 175e49a0df73d6256146152591bf599bd3a9734b Mon Sep 17 00:00:00 2001
From: Xinan Jiang <xinan.jxn@gmail.com>
Date: Tue, 2 Feb 2021 11:56:46 +0800
Subject: [PATCH 4/4] [CostModel] Add TF_ASSERT_OK

---
 tensorflow/core/graph/costmodel_test.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/graph/costmodel_test.cc b/tensorflow/core/graph/costmodel_test.cc
index b397b5953c7..671234de491 100644
--- a/tensorflow/core/graph/costmodel_test.cc
+++ b/tensorflow/core/graph/costmodel_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/cost_graph.pb.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/dump_graph.h"
@@ -83,8 +84,10 @@ TEST(CostModelTest, GlobalId) {
   CostModelManager cost_model_manager;
   collector.BuildCostModel(&cost_model_manager, device_map);
   CostGraphDef cost_graph_def;
-  cost_model_manager.AddToCostGraphDef(graph1.get(), &cost_graph_def);
-  cost_model_manager.AddToCostGraphDef(graph2.get(), &cost_graph_def);
+  TF_ASSERT_OK(
+      cost_model_manager.AddToCostGraphDef(graph1.get(), &cost_graph_def));
+  TF_ASSERT_OK(
+      cost_model_manager.AddToCostGraphDef(graph2.get(), &cost_graph_def));
   ASSERT_EQ(cost_graph_def.node_size(), 12);
   absl::flat_hash_map<int32, const CostGraphDef::Node> ids;
   for (auto node : cost_graph_def.node()) {