Don't auto-cluster functions with _scoped_allocator annotations

PiperOrigin-RevId: 239068734
This commit is contained in:
Sanjoy Das 2019-03-18 15:02:20 -07:00 committed by TensorFlower Gardener
parent 2e7bf1d595
commit b155f85e87
2 changed files with 64 additions and 0 deletions

View File

@ -629,6 +629,16 @@ Status FindCompilationCandidates(
if (node->type_string() == "_Retval") {
continue;
}
if (node->attrs().Find("_scoped_allocator") ||
node->attrs().Find("_forward_from")) {
// TODO(b/128858118): XLA does not support _scoped_allocator and
// _forward_from.
VLOG(2) << "Not clustering " << node->name()
<< " because of _scoped_allocator or _forward_from attribute.";
continue;
}
candidates->insert(node);
--fuel;
}

View File

@ -1354,5 +1354,59 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
EXPECT_EQ(clusters[resource_read_name], "");
}
TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output x = ops::Add(root.WithOpName("test/x"), a, b);
Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
Output z = ops::Add(root.WithOpName("test/z"), x, y);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
std::vector<int> scoped_allocator_value;
scoped_allocator_value.push_back(0);
scoped_allocator_value.push_back(155);
FindNodeByName(graph.get(), "test/z")
->AddAttr("_scoped_allocator", scoped_allocator_value);
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/z"], "");
}
TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
Output x = ops::Add(root.WithOpName("test/x"), a, b);
Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
Output z = ops::Add(root.WithOpName("test/z"), x, y);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/z"], "");
}
} // namespace
} // namespace tensorflow