Add a utility function to build node name to node index.

PiperOrigin-RevId: 216853788
This commit is contained in:
Tong Shen 2018-10-12 06:33:28 -07:00 committed by TensorFlower Gardener
parent 9e0fa95786
commit 72bf28cd1f
7 changed files with 44 additions and 35 deletions

View File

@ -256,7 +256,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex();
string function = index.at("launch0")->type_string();
// Tests the outer graph is as expected.
@ -291,7 +291,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
// function. Encapsulation should be deterministic to avoid recompilation.
TF_ASSERT_OK(
EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
std::unordered_map<string, Node*> index_copy =
graph_copy->BuildNodeNameIndex();
string function_copy = index_copy.at("launch0")->type_string();
EXPECT_EQ(function, function_copy);
}

View File

@ -40,12 +40,4 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
std::unordered_map<string, Node*> index;
for (Node* node : graph.nodes()) {
index[node->name()] = node;
}
return index;
}
} // namespace tensorflow

View File

@ -44,9 +44,6 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
// Builds a map from node name to Node* for `graph`.
std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
} // namespace tensorflow
// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for

View File

@ -70,15 +70,6 @@ class ConstantFoldingTest : public ::testing::Test {
test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
}
// Builds a map from node name to Node* for `graph`.
std::unordered_map<string, Node*> NodeNameIndex(const Graph& graph) {
std::unordered_map<string, Node*> index;
for (Node* node : graph.nodes()) {
index[node->name()] = node;
}
return index;
}
// Constructs the following graph.
/*
s1 s2
@ -110,7 +101,7 @@ TEST_F(ConstantFoldingTest, Basic) {
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* s1 = index.at("s1");
Node* s2 = index.at("s2");
// Nodes s1 and s2 now should now have a constant input
@ -165,7 +156,7 @@ TEST_F(ConstantFoldingTest, DeterministicFolding) {
Graph g2(OpRegistry::Global());
TF_ASSERT_OK(build_graph_and_constant_folding(g2, true));
EXPECT_EQ(g1.num_nodes(), g2.num_nodes());
auto index = NodeNameIndex(g2);
auto index = g2.BuildNodeNameIndex();
// All the nodes in g1 are expected to be present in g2.
for (int64 i = 0; i < g1.num_nodes(); ++i) {
@ -188,7 +179,7 @@ TEST_F(ConstantFoldingTest, ConsiderFunction) {
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* s1 = index.at("s1");
Node* s2 = index.at("s2");
Node* m2 = index.at("m2");
@ -217,7 +208,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* d = index.at("d");
Node* s3 = index.at("s3");
@ -245,7 +236,7 @@ TEST_F(ConstantFoldingTest, TwoOutputs) {
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* b0 = index.at("b0");
Node* b1 = index.at("b1");
@ -277,7 +268,7 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* b0 = index.at("b0");
Node* b1 = index.at("b1");
Node* b1_ident = index.at("b1_ident");
@ -412,7 +403,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* recv1 = index.at("recv1");
Node* recv2 = index.at("recv2");
Node* send = index.at("send");
@ -454,7 +445,7 @@ TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
Node* recv0 = orig_index.at("recv0");
Node* recv1 = orig_index.at("recv1");
PartialTensorShape ps0;
@ -473,7 +464,7 @@ TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* recv2 = index.at("recv2");
Node* send0 = index.at("send0");
Node* send1 = index.at("send1");
@ -533,7 +524,7 @@ TEST_F(ConstantFoldingTest, PartialShape) {
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
Node* recv0 = orig_index.at("recv0");
Node* recv1 = orig_index.at("recv1");
PartialTensorShape ps0;
@ -550,7 +541,7 @@ TEST_F(ConstantFoldingTest, PartialShape) {
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* shape = index.at("shape");
Node* size = index.at("size");
Node* rank1 = index.at("rank1");
@ -590,7 +581,7 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) {
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
Node* c0 = orig_index.at("c0");
PartialTensorShape ps0;
int c0_dims[] = {};
@ -604,7 +595,7 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) {
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
Node* recv0 = index.at("recv0");
Node* send0 = index.at("send0");

View File

@ -750,6 +750,14 @@ Status Graph::AddWhileContext(StringPiece frame_name,
return Status::OK();
}
std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
std::unordered_map<string, Node*> result;
for (Node* n : nodes()) {
result[n->name()] = n;
}
return result;
}
string Edge::DebugString() const {
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
src_output_, dst_->name().c_str(), dst_input_);

View File

@ -614,6 +614,9 @@ class Graph {
std::vector<OutputTensor> body_outputs,
WhileContext** result);
// Builds a node name to node pointer index for all nodes in the graph.
std::unordered_map<string, Node*> BuildNodeNameIndex() const;
// TODO(josh11b): uint64 hash() const;
private:

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include <set>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
@ -643,6 +644,22 @@ TEST_F(GraphTest, AddFunctionLibrary) {
"because it already has gradient function 'Undefined'");
}
TEST_F(GraphTest, BuildNodeNameIndex) {
FromGraphDef(
"node { name: 'A' op: 'OneOutput' }"
"node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
"node { name: 'C' op: 'NoOp' } ");
auto node_name_index = graph_.BuildNodeNameIndex();
EXPECT_EQ(node_name_index.size(), 5);
std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
for (const string& node_name : node_names) {
EXPECT_NE(node_name_index.find(node_name), node_name_index.end());
EXPECT_EQ(node_name_index[node_name], FindNode(node_name));
}
}
REGISTER_OP("Input").Output("o: float");
REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float");