diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 39f97fff9a2..184afe3aa8e 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -629,6 +629,16 @@ Status FindCompilationCandidates(
     if (node->type_string() == "_Retval") {
       continue;
     }
+
+    if (node->attrs().Find("_scoped_allocator") ||
+        node->attrs().Find("_forward_from")) {
+      // TODO(b/128858118): XLA does not support _scoped_allocator and
+      // _forward_from.
+      VLOG(2) << "Not clustering " << node->name()
+              << " because of _scoped_allocator or _forward_from attribute.";
+      continue;
+    }
+
     candidates->insert(node);
     --fuel;
   }
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 897ea0711bc..0129f3a486d 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -1354,5 +1354,59 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
   EXPECT_EQ(clusters[resource_read_name], "");
 }
 
+TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
+  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
+
+  Output x = ops::Add(root.WithOpName("test/x"), a, b);
+  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
+  Output z = ops::Add(root.WithOpName("test/z"), x, y);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
+  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
+  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
+
+  std::vector<int> scoped_allocator_value;
+  scoped_allocator_value.push_back(0);
+  scoped_allocator_value.push_back(155);
+  FindNodeByName(graph.get(), "test/z")
+      ->AddAttr("_scoped_allocator", scoped_allocator_value);
+
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+  EXPECT_EQ(clusters["test/z"], "");
+}
+
+TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
+  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
+
+  Output x = ops::Add(root.WithOpName("test/x"), a, b);
+  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
+  Output z = ops::Add(root.WithOpName("test/z"), x, y);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
+  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
+  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);
+
+  FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);
+
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+  EXPECT_EQ(clusters["test/z"], "");
+}
+
 }  // namespace
 }  // namespace tensorflow