diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 1bd6635d9bb..da9a5e277d1 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -46,6 +46,11 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:notification", + "//tensorflow/core/platform:types", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 7c176fedfe6..b28876a1702 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -80,10 +80,12 @@ NodeMap::NodeMap(GraphDef* graph) { auto rslt = nodes_.emplace(node_name, node); // Check that the graph doesn't contain multiple nodes with the same name. if (!rslt.second) { + // The first node found with a given name becomes the canonical. LOG(WARNING) << "Duplicated node in the graph: " << node_name; } + NodeDef* canonical = rslt.second ? node : rslt.first->second; for (const auto& input : node->input()) { - outputs_[NodeName(input)].insert(nodes_[node_name]); + outputs_[NodeName(input)].insert(canonical); } } } diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 4d53bf9ced5..a65704147ea 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include + #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/benchmark_testlib.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" @@ -599,6 +601,17 @@ TEST(SetTensorValueTest, Quantized) { /*error_msg=*/""); } +static void BM_NodeMapConstruct(int iters, int size) { + testing::StopTiming(); + GraphDef graph = test::CreateRandomGraph(size); + testing::StartTiming(); + for (int i = 0; i < iters; i++) { + NodeMap node_map(&graph); + } + testing::StopTiming(); +} +BENCHMARK(BM_NodeMapConstruct)->Range(1, 1 << 20); + } // namespace } // namespace grappler } // namespace tensorflow