From bea72558323bece442b57bf27ec773734c09e324 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 14 Jun 2017 15:25:58 -0700 Subject: [PATCH] Made sure that the nodes listed as feed, fetch and init_op exist in the graph. PiperOrigin-RevId: 159034290 --- tensorflow/core/grappler/grappler_item.cc | 2 +- .../core/grappler/grappler_item_builder.cc | 23 +++++++++++++++++++ .../grappler/grappler_item_builder_test.cc | 11 ++++++--- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 312a457abf4..88ddd6c1b3c 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -55,7 +55,7 @@ std::vector ComputeTransitiveFanin( std::vector queue; for (const string& root : terminal_nodes) { const NodeDef* node = name_to_node[NodeName(root)]; - CHECK(node); + CHECK(node) << "Unknown root " << root; queue.push_back(node); } diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index bb36152bd87..969376917be 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -332,6 +332,29 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( return nullptr; } + // Validate feed, fetch and init nodes + std::unordered_set nodes; + for (const auto& node : new_item->graph.node()) { + nodes.insert(node.name()); + } + for (const auto& feed : new_item->feed) { + if (nodes.find(feed.first) == nodes.end()) { + LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph"; + return nullptr; + } + } + for (const auto& fetch : new_item->fetch) { + if (nodes.find(fetch) == nodes.end()) { + LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph"; + return nullptr; + } + } + for (const auto& init : new_item->init_ops) { + if (nodes.find(init) == nodes.end()) { + LOG(ERROR) << "Init node " << init << " doesn't exist in graph"; + return nullptr; + } + } return new_item; } diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 92225ffb1b4..048870f9e51 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -51,7 +51,11 @@ void SampleSumSymbolicGradientGraphdef( auto g0 = SymbolicGradient(scope, std::initializer_list{x, y, z}, {DT_FLOAT, DT_INT32}, fn); - fetches->mutable_node_list()->add_value(g0[0].name()); + // TODO(bsteiner): we should rewrite the feed/fetch nodes to reflect the + // inlining that's done in the item builder + // fetches->mutable_node_list()->add_value(g0[0].name()); + fetches->mutable_node_list()->add_value("SymbolicGradient/dx"); + fetches->mutable_node_list()->add_value("SymbolicGradient/dy_reshaped"); TF_CHECK_OK(scope.ToGraphDef(def)); @@ -109,11 +113,12 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) { std::unique_ptr with_inline = CreateGrapplerItem(def, fetches); // For the inlined graph, there should be 0 symbolic gradient ops. - CHECK_EQ(0, CountSymbolicGradientOps(with_inline)); + EXPECT_EQ(0, CountSymbolicGradientOps(with_inline)); // For the inlined graph, make sure all the required expanded op’s are in the // graph. - CHECK_EQ(ops_of_inline.size(), CountOpsWithNames(with_inline, ops_of_inline)); + EXPECT_EQ(ops_of_inline.size(), + CountOpsWithNames(with_inline, ops_of_inline)); } } // namespace