Fix for outside compilation.
When generating Placeholder node for host computation/outside compilation nodes, we have to generate different Placeholder node for different src_output edges. PiperOrigin-RevId: 223544322
This commit is contained in:
parent
5b14577d42
commit
92e236dd24
@ -241,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
|||||||
|
|
||||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||||
// outside compilation node input.
|
// outside compilation node input.
|
||||||
std::map<string, Node*> placeholders;
|
std::map<std::pair<string, int>, Node*> placeholders;
|
||||||
for (int i = 0; i < edges.size(); i++) {
|
for (int i = 0; i < edges.size(); i++) {
|
||||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||||
const Edge* e;
|
const Edge* e;
|
||||||
@ -253,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
|||||||
// Find or create placeholder node.
|
// Find or create placeholder node.
|
||||||
string new_name =
|
string new_name =
|
||||||
edges[i].is_host_to_outside_compilation
|
edges[i].is_host_to_outside_compilation
|
||||||
? absl::StrCat(src->name(), "_host_to_oc_placeholder")
|
? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
|
||||||
: absl::StrCat(src->name(), "_oc_to_host_placeholder");
|
: absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
|
||||||
auto iter = placeholders.find(new_name);
|
auto placeholder_index = std::make_pair(src->name(), src_output);
|
||||||
|
auto iter = placeholders.find(placeholder_index);
|
||||||
Node* placeholder_node;
|
Node* placeholder_node;
|
||||||
if (iter == placeholders.end()) {
|
if (iter == placeholders.end()) {
|
||||||
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
||||||
@ -288,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
|
|||||||
Status s;
|
Status s;
|
||||||
placeholder_node = g->AddNode(placeholder_def, &s);
|
placeholder_node = g->AddNode(placeholder_def, &s);
|
||||||
TF_RETURN_IF_ERROR(s);
|
TF_RETURN_IF_ERROR(s);
|
||||||
placeholders[new_name] = placeholder_node;
|
placeholders[placeholder_index] = placeholder_node;
|
||||||
} else {
|
} else {
|
||||||
placeholder_node = iter->second;
|
placeholder_node = iter->second;
|
||||||
}
|
}
|
||||||
@ -642,7 +643,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
|||||||
|
|
||||||
// Remove the edge from host to outside compilation. Add a placeholder as
|
// Remove the edge from host to outside compilation. Add a placeholder as
|
||||||
// outside compilation node input.
|
// outside compilation node input.
|
||||||
std::map<string, Node*> placeholders;
|
std::map<std::pair<string, int>, Node*> placeholders;
|
||||||
for (int i = 0; i < edges.size(); i++) {
|
for (int i = 0; i < edges.size(); i++) {
|
||||||
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
Node* dst = g->FindNodeId(edges[i].dst_node_id);
|
||||||
const Edge* e;
|
const Edge* e;
|
||||||
@ -652,8 +653,10 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
|||||||
g->RemoveEdge(e);
|
g->RemoveEdge(e);
|
||||||
|
|
||||||
// Find or create placeholder node.
|
// Find or create placeholder node.
|
||||||
string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder");
|
string new_name =
|
||||||
auto iter = placeholders.find(new_name);
|
absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
|
||||||
|
auto placeholder_index = std::make_pair(src->name(), src_output);
|
||||||
|
auto iter = placeholders.find(placeholder_index);
|
||||||
Node* placeholder_node;
|
Node* placeholder_node;
|
||||||
if (iter == placeholders.end()) {
|
if (iter == placeholders.end()) {
|
||||||
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
|
||||||
@ -673,7 +676,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations(
|
|||||||
Status s;
|
Status s;
|
||||||
placeholder_node = g->AddNode(placeholder_def, &s);
|
placeholder_node = g->AddNode(placeholder_def, &s);
|
||||||
TF_RETURN_IF_ERROR(s);
|
TF_RETURN_IF_ERROR(s);
|
||||||
placeholders[new_name] = placeholder_node;
|
placeholders[placeholder_index] = placeholder_node;
|
||||||
} else {
|
} else {
|
||||||
placeholder_node = iter->second;
|
placeholder_node = iter->second;
|
||||||
}
|
}
|
||||||
|
@ -153,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
|
|||||||
TEST(PreprocessForEncapsulationTest, DataEdges) {
|
TEST(PreprocessForEncapsulationTest, DataEdges) {
|
||||||
// Build the graph:
|
// Build the graph:
|
||||||
// "const_0" and "const_1" in host computation
|
// "const_0" and "const_1" in host computation
|
||||||
|
// "identityn0" = ("const_0", "const_1") in host computation 0
|
||||||
// "add0" = "const_0" + "const_1" in XLA computation 0
|
// "add0" = "const_0" + "const_1" in XLA computation 0
|
||||||
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
|
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
|
||||||
// "identity0" = "add1" in XLA computation 0
|
// "identity0" = "add1" in XLA computation 0
|
||||||
// "add2" = "add1" + "identity0" in host computation
|
// "add2" = "add1" + "identity0" in host computation
|
||||||
// "add3" = "add1" + "add2" in XLA computation 1
|
// "add3" = "add1" + "add2" in XLA computation 1
|
||||||
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
|
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
|
||||||
|
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
|
||||||
|
// outside compilation 0
|
||||||
|
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
|
||||||
|
// outside compilation 0
|
||||||
// "identity1" = "add4" in XLA computation 1
|
// "identity1" = "add4" in XLA computation 1
|
||||||
// "identity2" = "identity1" in host computation
|
// "identity2" = "identity1" in host computation
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
|
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
|
||||||
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
|
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
|
||||||
|
auto identityn0 =
|
||||||
|
ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
|
||||||
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
|
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
|
||||||
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
|
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
|
||||||
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
|
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
|
||||||
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
|
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
|
||||||
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
|
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
|
||||||
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
|
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
|
||||||
|
Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
|
||||||
|
auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
|
||||||
|
{identityn0[0], identityn0[1]});
|
||||||
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
|
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
|
||||||
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
|
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
@ -180,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
|||||||
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
|
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
|
||||||
*identity0_node = node_index["identity0"],
|
*identity0_node = node_index["identity0"],
|
||||||
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
|
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
|
||||||
|
*add5_node = node_index["add5"],
|
||||||
|
*identityn1_node = node_index["identityn_1"],
|
||||||
*identity1_node = node_index["identity1"];
|
*identity1_node = node_index["identity1"];
|
||||||
add0_node->AddAttr("_xla", "0");
|
add0_node->AddAttr("_xla", "0");
|
||||||
add1_node->AddAttr("_xla", "0");
|
add1_node->AddAttr("_xla", "0");
|
||||||
@ -188,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
|||||||
add3_node->AddAttr("_xla", "1");
|
add3_node->AddAttr("_xla", "1");
|
||||||
add4_node->AddAttr("_xla", "1");
|
add4_node->AddAttr("_xla", "1");
|
||||||
add4_node->AddAttr("_oc", "0");
|
add4_node->AddAttr("_oc", "0");
|
||||||
|
add5_node->AddAttr("_xla", "1");
|
||||||
|
add5_node->AddAttr("_oc", "0");
|
||||||
|
identityn1_node->AddAttr("_xla", "1");
|
||||||
|
identityn1_node->AddAttr("_oc", "0");
|
||||||
identity1_node->AddAttr("_xla", "1");
|
identity1_node->AddAttr("_xla", "1");
|
||||||
|
|
||||||
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
|
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
|
||||||
@ -205,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
|||||||
EXPECT_NE(bridge_identity0_add4, nullptr);
|
EXPECT_NE(bridge_identity0_add4, nullptr);
|
||||||
// Step 3: add placeholder for edges between host computation and outside
|
// Step 3: add placeholder for edges between host computation and outside
|
||||||
// compilation.
|
// compilation.
|
||||||
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
|
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
|
||||||
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
|
Node *add1_oc_to_host_placeholder =
|
||||||
|
node_index["add1_oc_to_host_placeholder_0"];
|
||||||
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
|
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
|
||||||
kOutsideCompilationToHostOriginalNodeAttrName, &str));
|
kOutsideCompilationToHostOriginalNodeAttrName, &str));
|
||||||
EXPECT_EQ(str, "add1");
|
EXPECT_EQ(str, "add1");
|
||||||
@ -217,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
|
|||||||
add4_node = node_index["add4"];
|
add4_node = node_index["add4"];
|
||||||
ASSERT_NE(add4_node, nullptr);
|
ASSERT_NE(add4_node, nullptr);
|
||||||
EXPECT_EQ(add4_node->def().input(0),
|
EXPECT_EQ(add4_node->def().input(0),
|
||||||
"bridge_identity0_add4_host_to_oc_placeholder");
|
"bridge_identity0_add4_host_to_oc_placeholder_0");
|
||||||
Node *identity0_host_to_oc_placeholder =
|
Node *identity0_host_to_oc_placeholder =
|
||||||
node_index["bridge_identity0_add4_host_to_oc_placeholder"];
|
node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
|
||||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||||
kHostToOutsideCompilationOriginalNodeAttrName, &str));
|
kHostToOutsideCompilationOriginalNodeAttrName, &str));
|
||||||
EXPECT_EQ(str, "bridge_identity0_add4");
|
EXPECT_EQ(str, "bridge_identity0_add4");
|
||||||
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
|
||||||
kHostToOutsideCompilationSrcOutputAttrName, &i));
|
kHostToOutsideCompilationSrcOutputAttrName, &i));
|
||||||
EXPECT_EQ(i, 0);
|
EXPECT_EQ(i, 0);
|
||||||
|
|
||||||
|
// Check different placeholder nodes are created for different src_output.
|
||||||
|
Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
|
||||||
|
*placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
|
||||||
|
EXPECT_NE(placeholder0, nullptr);
|
||||||
|
EXPECT_NE(placeholder1, nullptr);
|
||||||
|
// Check we only have 2 placeholder nodes created for "identityn_0".
|
||||||
|
int placeholder_count = 0;
|
||||||
|
for (Node *n : g.nodes()) {
|
||||||
|
if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
|
||||||
|
string attr;
|
||||||
|
TF_CHECK_OK(GetNodeAttr(
|
||||||
|
n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
|
||||||
|
if (attr == "identityn_0") {
|
||||||
|
++placeholder_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(placeholder_count, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(PostprocessForEncapsulationTest, ControlEdges) {
|
TEST(PostprocessForEncapsulationTest, ControlEdges) {
|
||||||
|
Loading…
Reference in New Issue
Block a user