Made sure that the nodes listed as feed, fetch and init_op exist in the graph.

PiperOrigin-RevId: 159034290
This commit is contained in:
Benoit Steiner 2017-06-14 15:25:58 -07:00 committed by TensorFlower Gardener
parent 69bc160235
commit bea7255832
3 changed files with 32 additions and 4 deletions

View File

@ -55,7 +55,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
std::vector<const NodeDef*> queue;
for (const string& root : terminal_nodes) {
const NodeDef* node = name_to_node[NodeName(root)];
CHECK(node);
CHECK(node) << "Unknown root " << root;
queue.push_back(node);
}

View File

@ -332,6 +332,29 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
return nullptr;
}
// Validate feed, fetch and init nodes
std::unordered_set<string> nodes;
for (const auto& node : new_item->graph.node()) {
nodes.insert(node.name());
}
for (const auto& feed : new_item->feed) {
if (nodes.find(feed.first) == nodes.end()) {
LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
return nullptr;
}
}
for (const auto& fetch : new_item->fetch) {
if (nodes.find(fetch) == nodes.end()) {
LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
return nullptr;
}
}
for (const auto& init : new_item->init_ops) {
if (nodes.find(init) == nodes.end()) {
LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
return nullptr;
}
}
return new_item;
}

View File

@ -51,7 +51,11 @@ void SampleSumSymbolicGradientGraphdef(
auto g0 = SymbolicGradient(scope, std::initializer_list<Input>{x, y, z},
{DT_FLOAT, DT_INT32}, fn);
fetches->mutable_node_list()->add_value(g0[0].name());
// TODO(bsteiner): we should rewrite the feed/fetch nodes to reflect the
// inlining that's done in the item builder
// fetches->mutable_node_list()->add_value(g0[0].name());
fetches->mutable_node_list()->add_value("SymbolicGradient/dx");
fetches->mutable_node_list()->add_value("SymbolicGradient/dy_reshaped");
TF_CHECK_OK(scope.ToGraphDef(def));
@ -109,11 +113,12 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) {
std::unique_ptr<GrapplerItem> with_inline = CreateGrapplerItem(def, fetches);
// For the inlined graph, there should be 0 symbolic gradient ops.
CHECK_EQ(0, CountSymbolicGradientOps(with_inline));
EXPECT_EQ(0, CountSymbolicGradientOps(with_inline));
// For the inlined graph, make sure all the required expanded ops are in the
// graph.
CHECK_EQ(ops_of_inline.size(), CountOpsWithNames(with_inline, ops_of_inline));
EXPECT_EQ(ops_of_inline.size(),
CountOpsWithNames(with_inline, ops_of_inline));
}
} // namespace