Made sure that the nodes listed as feed, fetch and init_op exist in the graph.
PiperOrigin-RevId: 159034290
This commit is contained in:
parent
69bc160235
commit
bea7255832
@ -55,7 +55,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
std::vector<const NodeDef*> queue;
|
||||
for (const string& root : terminal_nodes) {
|
||||
const NodeDef* node = name_to_node[NodeName(root)];
|
||||
CHECK(node);
|
||||
CHECK(node) << "Unknown root " << root;
|
||||
queue.push_back(node);
|
||||
}
|
||||
|
||||
|
@ -332,6 +332,29 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Validate feed, fetch and init nodes
|
||||
std::unordered_set<string> nodes;
|
||||
for (const auto& node : new_item->graph.node()) {
|
||||
nodes.insert(node.name());
|
||||
}
|
||||
for (const auto& feed : new_item->feed) {
|
||||
if (nodes.find(feed.first) == nodes.end()) {
|
||||
LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
for (const auto& fetch : new_item->fetch) {
|
||||
if (nodes.find(fetch) == nodes.end()) {
|
||||
LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
for (const auto& init : new_item->init_ops) {
|
||||
if (nodes.find(init) == nodes.end()) {
|
||||
LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return new_item;
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,11 @@ void SampleSumSymbolicGradientGraphdef(
|
||||
auto g0 = SymbolicGradient(scope, std::initializer_list<Input>{x, y, z},
|
||||
{DT_FLOAT, DT_INT32}, fn);
|
||||
|
||||
fetches->mutable_node_list()->add_value(g0[0].name());
|
||||
// TODO(bsteiner): we should rewrite the feed/fetch nodes to reflect the
|
||||
// inlining that's done in the item builder
|
||||
// fetches->mutable_node_list()->add_value(g0[0].name());
|
||||
fetches->mutable_node_list()->add_value("SymbolicGradient/dx");
|
||||
fetches->mutable_node_list()->add_value("SymbolicGradient/dy_reshaped");
|
||||
|
||||
TF_CHECK_OK(scope.ToGraphDef(def));
|
||||
|
||||
@ -109,11 +113,12 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) {
|
||||
std::unique_ptr<GrapplerItem> with_inline = CreateGrapplerItem(def, fetches);
|
||||
|
||||
// For the inlined graph, there should be 0 symbolic gradient ops.
|
||||
CHECK_EQ(0, CountSymbolicGradientOps(with_inline));
|
||||
EXPECT_EQ(0, CountSymbolicGradientOps(with_inline));
|
||||
|
||||
// For the inlined graph, make sure all the required expanded op’s are in the
|
||||
// graph.
|
||||
CHECK_EQ(ops_of_inline.size(), CountOpsWithNames(with_inline, ops_of_inline));
|
||||
EXPECT_EQ(ops_of_inline.size(),
|
||||
CountOpsWithNames(with_inline, ops_of_inline));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user