Don't auto-cluster functions with _scoped_allocator annotations
PiperOrigin-RevId: 239068734
This commit is contained in:
parent
2e7bf1d595
commit
b155f85e87
@ -629,6 +629,16 @@ Status FindCompilationCandidates(
|
|||||||
if (node->type_string() == "_Retval") {
|
if (node->type_string() == "_Retval") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node->attrs().Find("_scoped_allocator") ||
|
||||||
|
node->attrs().Find("_forward_from")) {
|
||||||
|
// TODO(b/128858118): XLA does not support _scoped_allocator and
|
||||||
|
// _forward_from.
|
||||||
|
VLOG(2) << "Not clustering " << node->name()
|
||||||
|
<< " because of _scoped_allocator or _forward_from attribute.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
candidates->insert(node);
|
candidates->insert(node);
|
||||||
--fuel;
|
--fuel;
|
||||||
}
|
}
|
||||||
|
@ -1354,5 +1354,59 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
|
|||||||
EXPECT_EQ(clusters[resource_read_name], "");
|
EXPECT_EQ(clusters[resource_read_name], "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||||
|
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output x = ops::Add(root.WithOpName("test/x"), a, b);
|
||||||
|
Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
|
||||||
|
Output z = ops::Add(root.WithOpName("test/z"), x, y);
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
|
||||||
|
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
|
||||||
|
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
|
||||||
|
|
||||||
|
std::vector<int> scoped_allocator_value;
|
||||||
|
scoped_allocator_value.push_back(0);
|
||||||
|
scoped_allocator_value.push_back(155);
|
||||||
|
FindNodeByName(graph.get(), "test/z")
|
||||||
|
->AddAttr("_scoped_allocator", scoped_allocator_value);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
|
|
||||||
|
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||||
|
|
||||||
|
EXPECT_EQ(clusters["test/z"], "");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
|
||||||
|
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output x = ops::Add(root.WithOpName("test/x"), a, b);
|
||||||
|
Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
|
||||||
|
Output z = ops::Add(root.WithOpName("test/z"), x, y);
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
|
||||||
|
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
|
||||||
|
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
|
||||||
|
|
||||||
|
FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
|
|
||||||
|
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||||
|
|
||||||
|
EXPECT_EQ(clusters["test/z"], "");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user