Add a utility function to build node name to node index.
PiperOrigin-RevId: 216853788
This commit is contained in:
parent
9e0fa95786
commit
72bf28cd1f
@ -256,7 +256,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
|
||||
|
||||
TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
|
||||
|
||||
std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
|
||||
std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex();
|
||||
string function = index.at("launch0")->type_string();
|
||||
|
||||
// Tests the outer graph is as expected.
|
||||
@ -291,7 +291,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
|
||||
// function. Encapsulation should be deterministic to avoid recompilation.
|
||||
TF_ASSERT_OK(
|
||||
EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
|
||||
std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
|
||||
std::unordered_map<string, Node*> index_copy =
|
||||
graph_copy->BuildNodeNameIndex();
|
||||
string function_copy = index_copy.at("launch0")->type_string();
|
||||
EXPECT_EQ(function, function_copy);
|
||||
}
|
||||
|
@ -40,12 +40,4 @@ Status InstantiateFunctionForTest(const string& name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
|
||||
std::unordered_map<string, Node*> index;
|
||||
for (Node* node : graph.nodes()) {
|
||||
index[node->name()] = node;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -44,9 +44,6 @@ Status InstantiateFunctionForTest(const string& name,
|
||||
const FunctionLibraryDefinition& library,
|
||||
InstantiationResultForTest* result);
|
||||
|
||||
// Builds a map from node name to Node* for `graph`.
|
||||
std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
|
||||
|
@ -70,15 +70,6 @@ class ConstantFoldingTest : public ::testing::Test {
|
||||
test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
|
||||
}
|
||||
|
||||
// Builds a map from node name to Node* for `graph`.
|
||||
std::unordered_map<string, Node*> NodeNameIndex(const Graph& graph) {
|
||||
std::unordered_map<string, Node*> index;
|
||||
for (Node* node : graph.nodes()) {
|
||||
index[node->name()] = node;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
// Constructs the following graph.
|
||||
/*
|
||||
s1 s2
|
||||
@ -110,7 +101,7 @@ TEST_F(ConstantFoldingTest, Basic) {
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* s1 = index.at("s1");
|
||||
Node* s2 = index.at("s2");
|
||||
// Nodes s1 and s2 now should now have a constant input
|
||||
@ -165,7 +156,7 @@ TEST_F(ConstantFoldingTest, DeterministicFolding) {
|
||||
Graph g2(OpRegistry::Global());
|
||||
TF_ASSERT_OK(build_graph_and_constant_folding(g2, true));
|
||||
EXPECT_EQ(g1.num_nodes(), g2.num_nodes());
|
||||
auto index = NodeNameIndex(g2);
|
||||
auto index = g2.BuildNodeNameIndex();
|
||||
|
||||
// All the nodes in g1 are expected to be present in g2.
|
||||
for (int64 i = 0; i < g1.num_nodes(); ++i) {
|
||||
@ -188,7 +179,7 @@ TEST_F(ConstantFoldingTest, ConsiderFunction) {
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* s1 = index.at("s1");
|
||||
Node* s2 = index.at("s2");
|
||||
Node* m2 = index.at("m2");
|
||||
@ -217,7 +208,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* d = index.at("d");
|
||||
Node* s3 = index.at("s3");
|
||||
|
||||
@ -245,7 +236,7 @@ TEST_F(ConstantFoldingTest, TwoOutputs) {
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* b0 = index.at("b0");
|
||||
Node* b1 = index.at("b1");
|
||||
|
||||
@ -277,7 +268,7 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* b0 = index.at("b0");
|
||||
Node* b1 = index.at("b1");
|
||||
Node* b1_ident = index.at("b1_ident");
|
||||
@ -412,7 +403,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* recv1 = index.at("recv1");
|
||||
Node* recv2 = index.at("recv2");
|
||||
Node* send = index.at("send");
|
||||
@ -454,7 +445,7 @@ TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
|
||||
"receiver");
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
|
||||
Node* recv0 = orig_index.at("recv0");
|
||||
Node* recv1 = orig_index.at("recv1");
|
||||
PartialTensorShape ps0;
|
||||
@ -473,7 +464,7 @@ TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* recv2 = index.at("recv2");
|
||||
Node* send0 = index.at("send0");
|
||||
Node* send1 = index.at("send1");
|
||||
@ -533,7 +524,7 @@ TEST_F(ConstantFoldingTest, PartialShape) {
|
||||
"receiver");
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
|
||||
Node* recv0 = orig_index.at("recv0");
|
||||
Node* recv1 = orig_index.at("recv1");
|
||||
PartialTensorShape ps0;
|
||||
@ -550,7 +541,7 @@ TEST_F(ConstantFoldingTest, PartialShape) {
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* shape = index.at("shape");
|
||||
Node* size = index.at("size");
|
||||
Node* rank1 = index.at("rank1");
|
||||
@ -590,7 +581,7 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) {
|
||||
"receiver");
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
|
||||
Node* c0 = orig_index.at("c0");
|
||||
PartialTensorShape ps0;
|
||||
int c0_dims[] = {};
|
||||
@ -604,7 +595,7 @@ TEST_F(ConstantFoldingTest, ConstShapeKnown) {
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
|
||||
Node* recv0 = index.at("recv0");
|
||||
Node* send0 = index.at("send0");
|
||||
|
||||
|
@ -750,6 +750,14 @@ Status Graph::AddWhileContext(StringPiece frame_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
|
||||
std::unordered_map<string, Node*> result;
|
||||
for (Node* n : nodes()) {
|
||||
result[n->name()] = n;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
string Edge::DebugString() const {
|
||||
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
|
||||
src_output_, dst_->name().c_str(), dst_input_);
|
||||
|
@ -614,6 +614,9 @@ class Graph {
|
||||
std::vector<OutputTensor> body_outputs,
|
||||
WhileContext** result);
|
||||
|
||||
// Builds a node name to node pointer index for all nodes in the graph.
|
||||
std::unordered_map<string, Node*> BuildNodeNameIndex() const;
|
||||
|
||||
// TODO(josh11b): uint64 hash() const;
|
||||
|
||||
private:
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
@ -643,6 +644,22 @@ TEST_F(GraphTest, AddFunctionLibrary) {
|
||||
"because it already has gradient function 'Undefined'");
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, BuildNodeNameIndex) {
|
||||
FromGraphDef(
|
||||
"node { name: 'A' op: 'OneOutput' }"
|
||||
"node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
|
||||
"node { name: 'C' op: 'NoOp' } ");
|
||||
|
||||
auto node_name_index = graph_.BuildNodeNameIndex();
|
||||
EXPECT_EQ(node_name_index.size(), 5);
|
||||
|
||||
std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
|
||||
for (const string& node_name : node_names) {
|
||||
EXPECT_NE(node_name_index.find(node_name), node_name_index.end());
|
||||
EXPECT_EQ(node_name_index[node_name], FindNode(node_name));
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_OP("Input").Output("o: float");
|
||||
REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user