diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index eed0e704013..744c0b75ae5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1906,6 +1906,7 @@ tf_cc_tests( "//tensorflow/core/example:feature_util_test.cc", "//tensorflow/core/graph:algorithm_test.cc", "//tensorflow/core/graph:control_flow_test.cc", + "//tensorflow/core/graph:costmodel_test.cc", "//tensorflow/core/graph:edgeset_test.cc", "//tensorflow/core/graph:graph_def_builder_test.cc", "//tensorflow/core/graph:graph_partition_test.cc", diff --git a/tensorflow/core/graph/BUILD b/tensorflow/core/graph/BUILD index e3b5076b9f6..ea95e6cce71 100644 --- a/tensorflow/core/graph/BUILD +++ b/tensorflow/core/graph/BUILD @@ -201,6 +201,7 @@ exports_files( srcs = [ "algorithm_test.cc", "control_flow_test.cc", + "costmodel_test.cc", "edgeset_test.cc", "graph_def_builder_test.cc", "graph_partition_test.cc", diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index 1df45d9b893..c1b9d7358b9 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/graph/costmodel.h" +#include #include + #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" @@ -479,11 +481,12 @@ void CostModel::AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph) const { std::vector inputs; std::vector control_inputs; + int offset = cost_graph->node_size(); for (const Node* n : graph->nodes()) { CostGraphDef::Node* cnode = cost_graph->add_node(); cnode->set_name(n->name()); cnode->set_device(n->assigned_device_name()); - cnode->set_id(Id(n)); + cnode->set_id(GlobalId(n, offset)); inputs.clear(); inputs.resize(n->num_inputs(), nullptr); @@ -502,7 +505,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph, for (const Edge* e : inputs) { CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info(); - input_info->set_preceding_node(Id(e->src())); + input_info->set_preceding_node(GlobalId(e->src(), offset)); input_info->set_preceding_port(e->src_output()); } @@ -528,7 +531,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph, } for (const Edge* e : control_inputs) { - cnode->add_control_input(Id(e->src())); + cnode->add_control_input(GlobalId(e->src(), offset)); } cnode->set_temporary_memory_size(TempMemorySize(n).value()); diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h index 2d94dd5cdc8..31568d7c889 100644 --- a/tensorflow/core/graph/costmodel.h +++ b/tensorflow/core/graph/costmodel.h @@ -66,6 +66,14 @@ class CostModel { } } + inline int GlobalId(const Node* n, int offset) const { + if (is_global_) { + return n->cost_id(); + } else { + return n->id() + offset; + } + } + // Initializes cost model for 'g'. void InitFromGraph(const Graph& g); diff --git a/tensorflow/core/graph/costmodel_test.cc b/tensorflow/core/graph/costmodel_test.cc new file mode 100644 index 00000000000..5bdfb04a859 --- /dev/null +++ b/tensorflow/core/graph/costmodel_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/graph/costmodel.h" + +#include +#include + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +namespace { + +static void InitGraph(const string& s, Graph* graph) { + GraphDef graph_def; + + auto parser = protobuf::TextFormat::Parser(); + CHECK(parser.MergeFromString(s, &graph_def)) << s; + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); +} + +static void GenerateStepStats(Graph* graph, StepStats* step_stats, + const string& device_name) { + // Fill RunMetadata's step_stats and partition_graphs fields. + DeviceStepStats* device_stepstats = step_stats->add_dev_stats(); + device_stepstats->set_device(device_name); + for (const auto& node_def : graph->nodes()) { + NodeExecStats* node_stats = device_stepstats->add_node_stats(); + node_stats->set_node_name(node_def->name()); + } +} + +REGISTER_OP("Input").Output("o: float").SetIsStateful(); + +TEST(CostModelTest, GlobalId) { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::unique_ptr graph1 = + std::unique_ptr(new Graph(OpRegistry::Global())); + std::unique_ptr graph2 = + std::unique_ptr(new Graph(OpRegistry::Global())); + InitGraph( + "node { name: 'A1' op: 'Input'}" + "node { name: 'B1' op: 'Input'}" + "node { name: 'C1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A1', 'B1'] }" + "node { name: 'D1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A1', 'B1'] }", + graph1.get()); + InitGraph( + "node { name: 'A2' op: 'Input'}" + "node { name: 'B2' op: 'Input'}" + "node { name: 'C2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A2', 'B2'] }" + "node { name: 'D2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A2', 'B2'] }", + graph2.get()); + StepStats step_stats; + GenerateStepStats(graph1.get(), &step_stats, "DummyDevice1"); + GenerateStepStats(graph2.get(), &step_stats, "DummyDevice2"); + StepStatsCollector collector(&step_stats); + std::unordered_map device_map; + device_map["DummyDevice1"] = graph1.get(); + device_map["DummyDevice2"] = graph2.get(); + CostModelManager cost_model_manager; + collector.BuildCostModel(&cost_model_manager, device_map); + CostGraphDef cost_graph_def; + TF_ASSERT_OK( + cost_model_manager.AddToCostGraphDef(graph1.get(), &cost_graph_def)); + TF_ASSERT_OK( + cost_model_manager.AddToCostGraphDef(graph2.get(), &cost_graph_def)); + ASSERT_EQ(cost_graph_def.node_size(), 12); + absl::flat_hash_map ids; + for (auto node : cost_graph_def.node()) { + int32 index = node.id(); + auto result = ids.insert({index, node}); + EXPECT_TRUE(result.second); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc index 0e16f9d3fb3..14eb15c620d 100644 --- a/tensorflow/core/util/dump_graph.cc +++ b/tensorflow/core/util/dump_graph.cc @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/core/util/dump_graph.h" +#include +#include + #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/lib/strings/proto_serialization.h" @@ -209,6 +212,25 @@ string DumpGraphDefToFile(const string& name, GraphDef const& graph_def, return filepath; } +string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def, + const string& dirname) { + string filepath; + std::unique_ptr file; + Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt", + &filepath, &file); + if (!status.ok()) { + return StrCat("(failed to create writable file: ", status.ToString(), ")"); + } + + status = WriteTextProtoToUniqueFile(graph_def, file.get()); + if (!status.ok()) { + return StrCat("(failed to dump Graph to '", filepath, + "': ", status.ToString(), ")"); + } + LOG(INFO) << "Dumped Graph to " << filepath; + return filepath; +} + string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def, const string& dirname) { diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h index e4428bd0206..3d3861c2aed 100644 --- a/tensorflow/core/util/dump_graph.h +++ b/tensorflow/core/util/dump_graph.h @@ -19,6 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ #define TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ +#include + +#include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" @@ -39,6 +42,10 @@ namespace tensorflow { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def, const string& dirname = ""); +// Similar to DumpGraphDefToFile, use CostGraphDef instead of GraphDef. +string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def, + const string& dirname = ""); + // Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph' // and an optional function library 'flib_def'. Returns the file name chosen. string DumpGraphToFile(const string& name, Graph const& graph,