Fix for outside compilation.
When generating Placeholder node for host computation/outside compilation nodes, we have to generate different Placeholder node for different src_output edges. PiperOrigin-RevId: 223544322
This commit is contained in:
parent
5b14577d42
commit
92e236dd24
@ -241,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
||||
|
||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||
// outside compilation node input.
|
||||
std::map<string, Node*> placeholders;
|
||||
std::map<std::pair<string, int>, Node*> placeholders;
|
||||
for (int i = 0; i < edges.size(); i++) {
|
||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||
const Edge* e;
|
||||
@ -253,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
||||
// Find or create placeholder node.
|
||||
string new_name =
|
||||
edges[i].is_host_to_outside_compilation
|
||||
? absl::StrCat(src->name(), "_host_to_oc_placeholder")
|
||||
: absl::StrCat(src->name(), "_oc_to_host_placeholder");
|
||||
auto iter = placeholders.find(new_name);
|
||||
? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
|
||||
: absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
|
||||
auto placeholder_index = std::make_pair(src->name(), src_output);
|
||||
auto iter = placeholders.find(placeholder_index);
|
||||
Node* placeholder_node;
|
||||
if (iter == placeholders.end()) {
|
||||
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
||||
@ -288,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
||||
Status s;
|
||||
placeholder_node = g->AddNode(placeholder_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
placeholders[new_name] = placeholder_node;
|
||||
placeholders[placeholder_index] = placeholder_node;
|
||||
} else {
|
||||
placeholder_node = iter->second;
|
||||
}
|
||||
@ -642,7 +643,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
||||
|
||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||
// outside compilation node input.
|
||||
std::map<string, Node*> placeholders;
|
||||
std::map<std::pair<string, int>, Node*> placeholders;
|
||||
for (int i = 0; i < edges.size(); i++) {
|
||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||
const Edge* e;
|
||||
@ -652,8 +653,10 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
||||
g->RemoveEdge(e);
|
||||
|
||||
// Find or create placeholder node.
|
||||
string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder");
|
||||
auto iter = placeholders.find(new_name);
|
||||
string new_name =
|
||||
absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
|
||||
auto placeholder_index = std::make_pair(src->name(), src_output);
|
||||
auto iter = placeholders.find(placeholder_index);
|
||||
Node* placeholder_node;
|
||||
if (iter == placeholders.end()) {
|
||||
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
||||
@ -673,7 +676,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
||||
Status s;
|
||||
placeholder_node = g->AddNode(placeholder_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
placeholders[new_name] = placeholder_node;
|
||||
placeholders[placeholder_index] = placeholder_node;
|
||||
} else {
|
||||
placeholder_node = iter->second;
|
||||
}
|
||||
|
@ -153,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
|
||||
TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
// Build the graph:
|
||||
// "const_0" and "const_1" in host computation
|
||||
// "identityn0" = ("const_0", "const_1") in host computation 0
|
||||
// "add0" = "const_0" + "const_1" in XLA computation 0
|
||||
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
|
||||
// "identity0" = "add1" in XLA computation 0
|
||||
// "add2" = "add1" + "identity0" in host computation
|
||||
// "add3" = "add1" + "add2" in XLA computation 1
|
||||
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
|
||||
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
|
||||
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
|
||||
// outside compilation 0
|
||||
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
|
||||
// outside compilation 0
|
||||
// "identity1" = "add4" in XLA computation 1
|
||||
// "identity2" = "identity1" in host computation
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
|
||||
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
|
||||
auto identityn0 =
|
||||
ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
|
||||
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
|
||||
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
|
||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
|
||||
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
|
||||
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
|
||||
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
|
||||
Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
|
||||
auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
|
||||
{identityn0[0], identityn0[1]});
|
||||
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
|
||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
|
||||
Graph g(OpRegistry::Global());
|
||||
@ -180,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
|
||||
*identity0_node = node_index["identity0"],
|
||||
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
|
||||
*add5_node = node_index["add5"],
|
||||
*identityn1_node = node_index["identityn_1"],
|
||||
*identity1_node = node_index["identity1"];
|
||||
add0_node->AddAttr("_xla", "0");
|
||||
add1_node->AddAttr("_xla", "0");
|
||||
@ -188,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
add3_node->AddAttr("_xla", "1");
|
||||
add4_node->AddAttr("_xla", "1");
|
||||
add4_node->AddAttr("_oc", "0");
|
||||
add5_node->AddAttr("_xla", "1");
|
||||
add5_node->AddAttr("_oc", "0");
|
||||
identityn1_node->AddAttr("_xla", "1");
|
||||
identityn1_node->AddAttr("_oc", "0");
|
||||
identity1_node->AddAttr("_xla", "1");
|
||||
|
||||
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
|
||||
@ -205,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
EXPECT_NE(bridge_identity0_add4, nullptr);
|
||||
// Step 3: add placeholder for edges between host computation and outside
|
||||
// compilation.
|
||||
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
|
||||
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
|
||||
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
|
||||
Node *add1_oc_to_host_placeholder =
|
||||
node_index["add1_oc_to_host_placeholder_0"];
|
||||
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
|
||||
kOutsideCompilationToHostOriginalNodeAttrName, &str));
|
||||
EXPECT_EQ(str, "add1");
|
||||
@ -217,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||
add4_node = node_index["add4"];
|
||||
ASSERT_NE(add4_node, nullptr);
|
||||
EXPECT_EQ(add4_node->def().input(0),
|
||||
"bridge_identity0_add4_host_to_oc_placeholder");
|
||||
"bridge_identity0_add4_host_to_oc_placeholder_0");
|
||||
Node *identity0_host_to_oc_placeholder =
|
||||
node_index["bridge_identity0_add4_host_to_oc_placeholder"];
|
||||
node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||
kHostToOutsideCompilationOriginalNodeAttrName, &str));
|
||||
EXPECT_EQ(str, "bridge_identity0_add4");
|
||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||
kHostToOutsideCompilationSrcOutputAttrName, &i));
|
||||
EXPECT_EQ(i, 0);
|
||||
|
||||
// Check different placeholder nodes are created for different src_output.
|
||||
Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
|
||||
*placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
|
||||
EXPECT_NE(placeholder0, nullptr);
|
||||
EXPECT_NE(placeholder1, nullptr);
|
||||
// Check we only have 2 placeholder nodes created for "identityn_0".
|
||||
int placeholder_count = 0;
|
||||
for (Node *n : g.nodes()) {
|
||||
if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
|
||||
string attr;
|
||||
TF_CHECK_OK(GetNodeAttr(
|
||||
n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
|
||||
if (attr == "identityn_0") {
|
||||
++placeholder_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(placeholder_count, 2);
|
||||
}
|
||||
|
||||
TEST(PostprocessForEncapsulationTest, ControlEdges) {
|
||||
|
Loading…
Reference in New Issue
Block a user