Add a new attribute _XlaAutoJitScope for cluster scoping.
In other words, cluster scoping no longer uses the _XlaScope attribute, which remains in effect only when auto_jit is off. This way, existing clustering behavior is unchanged, and the new attribute is used only when auto_jit is on.
This commit is contained in:
parent
4a4ecd789c
commit
dda35dfbd8
@ -54,7 +54,7 @@ class ClusterScopingPassImpl {
|
||||
|
||||
absl::optional<string> GetXlaScope(Node* node) {
|
||||
string scope;
|
||||
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
|
||||
if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) {
|
||||
return scope;
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ absl::optional<string> GetXlaScope(Node* node) {
|
||||
}
|
||||
|
||||
void SetXlaScope(Node* node, StringPiece scope) {
|
||||
node->AddAttr(kXlaScopeAttr, scope);
|
||||
node->AddAttr(kXlaAutoJitScopeAttr, scope);
|
||||
}
|
||||
|
||||
// NB! We append new scope as suffix to the _XlaAutoJitScope attribute instead of
|
||||
|
@ -53,7 +53,7 @@ absl::flat_hash_map<string, string> GetXlaScopes(const Graph& graph) {
|
||||
absl::flat_hash_map<string, string> scopes;
|
||||
for (Node* node : graph.nodes()) {
|
||||
string scope;
|
||||
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
|
||||
if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) {
|
||||
scopes[node->name()] = scope;
|
||||
}
|
||||
}
|
||||
@ -85,11 +85,15 @@ TEST(XlaCompilationTest, StagePipelinePreserved) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
{
|
||||
// Graph:
|
||||
// a ->
|
||||
// b -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
|
||||
//
|
||||
// unstage ->
|
||||
// b -> add1 (ClusterY) -> relu1 (ClusterY)
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 (ClusterY) -> relu1 (ClusterY)
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
@ -125,11 +129,15 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
{
|
||||
// Graph:
|
||||
// a ->
|
||||
// b -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
|
||||
//
|
||||
// unstage ->
|
||||
// b -> add1 (ClusterC) -> relu1 (ClusterD)
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 (ClusterC) -> relu1 (ClusterD)
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
@ -145,17 +153,17 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
|
||||
|
||||
// Intentionally give add0 and add1 the same initial scope but they should
|
||||
// be separated by the ClusterScopingPass.
|
||||
Node* add0 = ops::BinaryOp(
|
||||
"Add", a, b,
|
||||
builder.opts().WithName("add0").WithAttr(kXlaScopeAttr, "ClusterA"));
|
||||
Node* add1 = ops::BinaryOp(
|
||||
"Add", unstage, b,
|
||||
builder.opts().WithName("add1").WithAttr(kXlaScopeAttr, "ClusterA"));
|
||||
Node* relu0 = ops::UnaryOp(
|
||||
"Relu", add0,
|
||||
builder.opts().WithName("relu0").WithAttr(kXlaScopeAttr, "ClusterB"));
|
||||
Node* add0 =
|
||||
ops::BinaryOp("Add", a, b, builder.opts().WithName("add0").WithAttr(
|
||||
kXlaAutoJitScopeAttr, "ClusterA"));
|
||||
Node* add1 = ops::BinaryOp("Add", unstage, b,
|
||||
builder.opts().WithName("add1").WithAttr(
|
||||
kXlaAutoJitScopeAttr, "ClusterA"));
|
||||
Node* relu0 =
|
||||
ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0").WithAttr(
|
||||
kXlaAutoJitScopeAttr, "ClusterB"));
|
||||
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1").WithAttr(
|
||||
kXlaScopeAttr, "ClusterD"));
|
||||
kXlaAutoJitScopeAttr, "ClusterD"));
|
||||
BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
|
@ -18,6 +18,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaCompileAttr = "_XlaCompile";
|
||||
|
||||
// User-provided through jit_scope. Effective when auto_jit is OFF.
|
||||
const char* const kXlaScopeAttr = "_XlaScope";
|
||||
|
||||
// Automatically inserted by auto_jit to guide clustering results. Effective
|
||||
// when auto_jit is ON.
|
||||
const char* const kXlaAutoJitScopeAttr = "_XlaAutoJitScope";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,6 +24,7 @@ namespace tensorflow {
|
||||
// Name of attribute used to tag operators for compilation with XLA
|
||||
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||
extern const char* const kXlaAutoJitScopeAttr; // "_XlaAutoJitScope"
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -923,20 +923,37 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
|
||||
}
|
||||
|
||||
absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
|
||||
// Look for an _XlaScope on both nodes. If both nodes have a scope and the
|
||||
// scopes do not match, do not cluster along this edge. This restriction is
|
||||
// overridden if the global_jit_level_ is ON. If even one of the nodes lacks
|
||||
// an _XlaScope attribute, then it is treated as a "bridge" and a cluster may
|
||||
// be created along it. We may want to restrict this behavior to require all
|
||||
// nodes marked with _XlaCompile=true to also have a _XlaScope property set
|
||||
// (and raise an error otherwise); but for now we don't do this.
|
||||
if (global_jit_level_ != OptimizerOptions::OFF) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
// Look for either _XlaScope or _XlaAutoJitScope on both nodes to guide
|
||||
// clustering. If both nodes have a scope and the scopes do not match, do
|
||||
// not cluster along this edge. If even one of the nodes lacks a scope
|
||||
// attribute, then it is treated as a "bridge" and a cluster may be created
|
||||
// along it.
|
||||
//
|
||||
// The difference between _XlaScope and _XlaAutoJitScope is that _XlaScope is
|
||||
// provided by users through jit_scope APIs, while _XlaAutoJitScope is
|
||||
// automatically generated by the ClusterScopingPass when auto_jit is on. As
|
||||
// such, we want to respect _XlaScope when auto_jit is off, while respecting
|
||||
// _XlaAutoJitScope when auto_jit is on.
|
||||
//
|
||||
// We may want to restrict the _XlaScope behavior to require all nodes marked
|
||||
// with _XlaCompile=true to also have a _XlaScope property set (and raise an
|
||||
// error otherwise); but for now we don't do this.
|
||||
|
||||
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
|
||||
if (!scope.empty()) {
|
||||
return scope;
|
||||
if (global_jit_level_ != OptimizerOptions::OFF) {
|
||||
// If global_jit_level_ is ON, respect kXlaAutoJitScope (and ignore
|
||||
// kXlaScope).
|
||||
const string& scope =
|
||||
GetNodeAttrString(node->attrs(), kXlaAutoJitScopeAttr);
|
||||
if (!scope.empty()) {
|
||||
return scope;
|
||||
}
|
||||
} else {
|
||||
// If global_jit_level_ is OFF, respect kXlaScope (and ignore
|
||||
// kXlaAutoJitScope).
|
||||
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
|
||||
if (!scope.empty()) {
|
||||
return scope;
|
||||
}
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
|
Loading…
x
Reference in New Issue
Block a user