Add a new attribute _XlaAutoJitScope for cluster scoping.

In other words, do not use the _XlaScope attribute, which has been used
when auto_jit is off. In this way, we do not change any existing clustering
behaviors and the new attribute is used only when auto_jit is on.
This commit is contained in:
Trent Lo 2019-08-15 16:47:28 -07:00
parent 4a4ecd789c
commit dda35dfbd8
5 changed files with 66 additions and 34 deletions

View File

@ -54,7 +54,7 @@ class ClusterScopingPassImpl {
absl::optional<string> GetXlaScope(Node* node) {
string scope;
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) {
return scope;
}
@ -62,7 +62,7 @@ absl::optional<string> GetXlaScope(Node* node) {
}
void SetXlaScope(Node* node, StringPiece scope) {
node->AddAttr(kXlaScopeAttr, scope);
node->AddAttr(kXlaAutoJitScopeAttr, scope);
}
// NB! We append new scope as suffix to the XlaScope attribute instead of

View File

@ -53,7 +53,7 @@ absl::flat_hash_map<string, string> GetXlaScopes(const Graph& graph) {
absl::flat_hash_map<string, string> scopes;
for (Node* node : graph.nodes()) {
string scope;
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) {
scopes[node->name()] = scope;
}
}
@ -85,11 +85,15 @@ TEST(XlaCompilationTest, StagePipelinePreserved) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
// Graph:
// a ->
// b -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
// b
// |
// v
// a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
//
// unstage ->
// b -> add1 (ClusterY) -> relu1 (ClusterY)
// b
// |
// v
// unstage -> add1 (ClusterY) -> relu1 (ClusterY)
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("a")
@ -125,11 +129,15 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
// Graph:
// a ->
// b -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
// b
// |
// v
// a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
//
// unstage ->
// b -> add1 (ClusterC) -> relu1 (ClusterD)
// b
// |
// v
// unstage -> add1 (ClusterC) -> relu1 (ClusterD)
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("a")
@ -145,17 +153,17 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
// Intentionally give add0 and add1 the same initial scope but they should
// be separated by the ClusterScopingPass.
Node* add0 = ops::BinaryOp(
"Add", a, b,
builder.opts().WithName("add0").WithAttr(kXlaScopeAttr, "ClusterA"));
Node* add1 = ops::BinaryOp(
"Add", unstage, b,
builder.opts().WithName("add1").WithAttr(kXlaScopeAttr, "ClusterA"));
Node* relu0 = ops::UnaryOp(
"Relu", add0,
builder.opts().WithName("relu0").WithAttr(kXlaScopeAttr, "ClusterB"));
Node* add0 =
ops::BinaryOp("Add", a, b, builder.opts().WithName("add0").WithAttr(
kXlaAutoJitScopeAttr, "ClusterA"));
Node* add1 = ops::BinaryOp("Add", unstage, b,
builder.opts().WithName("add1").WithAttr(
kXlaAutoJitScopeAttr, "ClusterA"));
Node* relu0 =
ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0").WithAttr(
kXlaAutoJitScopeAttr, "ClusterB"));
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1").WithAttr(
kXlaScopeAttr, "ClusterD"));
kXlaAutoJitScopeAttr, "ClusterD"));
BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));

View File

@ -18,6 +18,12 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaCompileAttr = "_XlaCompile";
// User-provided through jit_scope. Effective when auto_jit is OFF.
const char* const kXlaScopeAttr = "_XlaScope";
// Automatically inserted by auto_jit to guide clustering results. Effective
// when auto_jit is ON.
const char* const kXlaAutoJitScopeAttr = "_XlaAutoJitScope";
} // namespace tensorflow

View File

@ -24,6 +24,7 @@ namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA
extern const char* const kXlaCompileAttr; // "_XlaCompile"
extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaAutoJitScopeAttr; // "_XlaAutoJitScope"
} // namespace tensorflow

View File

@ -923,20 +923,37 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
}
absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
// Look for an _XlaScope on both nodes. If both nodes have a scope and the
// scopes do not match, do not cluster along this edge. This restriction is
// overridden if the global_jit_level_ is ON. If even one of the nodes lacks
// an _XlaScope attribute, then it is treated as a "bridge" and a cluster may
// be created along it. We may want to restrict this behavior to require all
// nodes marked with _XlaCompile=true to also have a _XlaScope property set
// (and raise an error otherwise); but for now we don't do this.
if (global_jit_level_ != OptimizerOptions::OFF) {
return absl::nullopt;
}
// Look for either _XlaScope or _XlaAutoJitScope on both nodes to guide
// clustering. If both nodes have a scope and the scopes do not match, do
// not cluster along this edge. If even one of the nodes lacks a scope
// attribute, then it is treated as a "bridge" and a cluster may be created
// along it.
//
// The difference between _XlaScope and _XlaAutoJitScope is that _XlaScope is
// provided by users through jit_scope APIs, while _XlaAutoJitScope is
// automatically generated by the ClusterScopingPass when auto_jit is on. As
// such, we want to respect _kXlaScope when auto_jit is off, while respecting
// _kXlaAutoJitScope when auto_jit is on.
//
// We may want to restrict the _XlaScope behavior to require all nodes marked
// with _XlaCompile=true to also have a _XlaScope property set (and raise an
// error otherwise); but for now we don't do this.
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
if (!scope.empty()) {
return scope;
if (global_jit_level_ != OptimizerOptions::OFF) {
// If global_jit_level_ is ON, respect kXlaAutoJitScope (and ignore
// kXlaScope).
const string& scope =
GetNodeAttrString(node->attrs(), kXlaAutoJitScopeAttr);
if (!scope.empty()) {
return scope;
}
} else {
// If global_jit_level_ is OFF, respect kXlaScope (and ignore
// kXlaAutoJitScope).
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
if (!scope.empty()) {
return scope;
}
}
return absl::nullopt;