diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc index 12c8c3f0217..d51a872f898 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -54,7 +54,7 @@ class ClusterScopingPassImpl { absl::optional GetXlaScope(Node* node) { string scope; - if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) { + if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) { return scope; } @@ -62,7 +62,7 @@ absl::optional GetXlaScope(Node* node) { } void SetXlaScope(Node* node, StringPiece scope) { - node->AddAttr(kXlaScopeAttr, scope); + node->AddAttr(kXlaAutoJitScopeAttr, scope); } // NB! We append new scope as suffix to the XlaScope attribute instead of diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc index 484a40cb8e1..9653d1e65bb 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -53,7 +53,7 @@ absl::flat_hash_map GetXlaScopes(const Graph& graph) { absl::flat_hash_map scopes; for (Node* node : graph.nodes()) { string scope; - if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) { + if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) { scopes[node->name()] = scope; } } @@ -85,11 +85,15 @@ TEST(XlaCompilationTest, StagePipelinePreserved) { std::unique_ptr graph(new Graph(OpRegistry::Global())); { // Graph: - // a -> - // b -> add0 (ClusterX) -> relu0 (ClusterX) -> stage + // b + // | + // v + // a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage // - // unstage -> - // b -> add1 (ClusterY) -> relu1 (ClusterY) + // b + // | + // v + // unstage -> add1 (ClusterY) -> relu1 (ClusterY) GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); Node* a = ops::SourceOp("Const", builder.opts() .WithName("a") @@ -125,11 +129,15 @@ TEST(XlaCompilationTest, 
StagePipelinePreservedAndInitialScopesRespected) { std::unique_ptr graph(new Graph(OpRegistry::Global())); { // Graph: - // a -> - // b -> add0 (ClusterA) -> relu0 (ClusterB) -> stage + // b + // | + // v + // a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage // - // unstage -> - // b -> add1 (ClusterC) -> relu1 (ClusterD) + // b + // | + // v + // unstage -> add1 (ClusterC) -> relu1 (ClusterD) GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); Node* a = ops::SourceOp("Const", builder.opts() .WithName("a") @@ -145,17 +153,17 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) { // Intentionally give add0 and add1 the same initial scope but they should // be separated by the ClusterScopingPass. - Node* add0 = ops::BinaryOp( - "Add", a, b, - builder.opts().WithName("add0").WithAttr(kXlaScopeAttr, "ClusterA")); - Node* add1 = ops::BinaryOp( - "Add", unstage, b, - builder.opts().WithName("add1").WithAttr(kXlaScopeAttr, "ClusterA")); - Node* relu0 = ops::UnaryOp( - "Relu", add0, - builder.opts().WithName("relu0").WithAttr(kXlaScopeAttr, "ClusterB")); + Node* add0 = + ops::BinaryOp("Add", a, b, builder.opts().WithName("add0").WithAttr( + kXlaAutoJitScopeAttr, "ClusterA")); + Node* add1 = ops::BinaryOp("Add", unstage, b, + builder.opts().WithName("add1").WithAttr( + kXlaAutoJitScopeAttr, "ClusterA")); + Node* relu0 = + ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0").WithAttr( + kXlaAutoJitScopeAttr, "ClusterB")); ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1").WithAttr( - kXlaScopeAttr, "ClusterD")); + kXlaAutoJitScopeAttr, "ClusterD")); BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0}); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index f847d66f3c6..e71011d8c5d 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -18,6 +18,12 @@ limitations under the License. 
namespace tensorflow { const char* const kXlaCompileAttr = "_XlaCompile"; + +// User-provided through jit_scope. Effective when auto_jit is OFF. const char* const kXlaScopeAttr = "_XlaScope"; +// Automatically inserted by auto_jit to guide clustering results. Effective +// when auto_jit is ON. +const char* const kXlaAutoJitScopeAttr = "_XlaAutoJitScope"; + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index a3aabc949db..dcaa4e03ec6 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -24,6 +24,7 @@ namespace tensorflow { // Name of attribute used to tag operators for compilation with XLA extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaScopeAttr; // "_XlaScope" +extern const char* const kXlaAutoJitScopeAttr; // "_XlaAutoJitScope" } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index b86ef934b45..8cfa69ab768 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -923,20 +923,37 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( } absl::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { - // Look for an _XlaScope on both nodes. If both nodes have a scope and the - // scopes do not match, do not cluster along this edge. This restriction is - // overridden if the global_jit_level_ is ON. If even one of the nodes lacks - // an _XlaScope attribute, then it is treated as a "bridge" and a cluster may - // be created along it. We may want to restrict this behavior to require all - // nodes marked with _XlaCompile=true to also have a _XlaScope property set - // (and raise an error otherwise); but for now we don't do this. 
- if (global_jit_level_ != OptimizerOptions::OFF) { - return absl::nullopt; - } + // Look for either _XlaScope or _XlaAutoJitScope on both nodes to guide + // clustering. If both nodes have a scope and the scopes do not match, do + // not cluster along this edge. If even one of the nodes lacks a scope + // attribute, then it is treated as a "bridge" and a cluster may be created + // along it. + // + // The difference between _XlaScope and _XlaAutoJitScope is that _XlaScope is + // provided by users through jit_scope APIs, while _XlaAutoJitScope is + // automatically generated by the ClusterScopingPass when auto_jit is on. As + // such, we want to respect _XlaScope when auto_jit is off, while respecting + // _XlaAutoJitScope when auto_jit is on. + // + // We may want to restrict the _XlaScope behavior to require all nodes marked + // with _XlaCompile=true to also have a _XlaScope property set (and raise an + // error otherwise); but for now we don't do this. - const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); - if (!scope.empty()) { - return scope; + if (global_jit_level_ != OptimizerOptions::OFF) { + // If global_jit_level_ is ON, respect kXlaAutoJitScope (and ignore + // kXlaScope). + const string& scope = + GetNodeAttrString(node->attrs(), kXlaAutoJitScopeAttr); + if (!scope.empty()) { + return scope; + } + } else { + // If global_jit_level_ is OFF, respect kXlaScope (and ignore + // kXlaAutoJitScope). + const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); + if (!scope.empty()) { + return scope; + } } return absl::nullopt;