[TF:XLA] Polish naming and comments for cluster_scoping_pass.

This commit is contained in:
Trent Lo 2019-08-15 17:23:31 -07:00
parent dda35dfbd8
commit 14bb933f42
5 changed files with 28 additions and 29 deletions

View File

@ -52,7 +52,7 @@ class ClusterScopingPassImpl {
size_t unique_scope_id_; size_t unique_scope_id_;
}; };
absl::optional<string> GetXlaScope(Node* node) { absl::optional<string> GetXlaAutoJitScope(Node* node) {
string scope; string scope;
if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) { if (GetNodeAttr(node->attrs(), kXlaAutoJitScopeAttr, &scope).ok()) {
return scope; return scope;
@ -61,23 +61,23 @@ absl::optional<string> GetXlaScope(Node* node) {
return absl::nullopt; return absl::nullopt;
} }
void SetXlaScope(Node* node, StringPiece scope) { void SetXlaAutoJitScope(Node* node, StringPiece scope) {
node->AddAttr(kXlaAutoJitScopeAttr, scope); node->AddAttr(kXlaAutoJitScopeAttr, scope);
} }
// NB! We append new scope as suffix to the XlaScope attribute instead of // NB! We append a new scope as suffix to the XlaAutoJitScope attribute instead
// overriding the old value. In this way, we respect the original scopes. // of overriding the old value. In this way, we respect the original scopes.
// In other words, appending X to Y creates the conjunction of the scopes X // In other words, appending X to Y creates the conjunction of the scopes X
// and Y (i.e, X & Y in effect). // and Y (i.e, X & Y in effect).
void AddOrAppendScope(Node* node, absl::string_view suffix) { void AddOrAppendXlaAutoJitScope(Node* node, absl::string_view suffix) {
string updated_scope; string updated_scope;
absl::optional<string> cur_scope = GetXlaScope(node); absl::optional<string> cur_scope = GetXlaAutoJitScope(node);
if (cur_scope == absl::nullopt) { if (cur_scope == absl::nullopt) {
updated_scope = std::string(suffix); updated_scope = std::string(suffix);
} else { } else {
updated_scope = absl::StrCat(cur_scope.value(), "&", suffix); updated_scope = absl::StrCat(cur_scope.value(), "&", suffix);
} }
SetXlaScope(node, updated_scope); SetXlaAutoJitScope(node, updated_scope);
} }
void ClusterScopingPassImpl::AddScopeToAllPredecessors(Node* start) { void ClusterScopingPassImpl::AddScopeToAllPredecessors(Node* start) {
@ -85,7 +85,7 @@ void ClusterScopingPassImpl::AddScopeToAllPredecessors(Node* start) {
std::vector<Node*> starts; std::vector<Node*> starts;
starts.push_back(start); starts.push_back(start);
auto enter = [&](Node* n) { AddOrAppendScope(n, unique_suffix); }; auto enter = [&](Node* n) { AddOrAppendXlaAutoJitScope(n, unique_suffix); };
ReverseDFSFrom(*graph_, starts, enter, /*leave=*/nullptr, ReverseDFSFrom(*graph_, starts, enter, /*leave=*/nullptr,
/*stable_comparator=*/NodeComparatorName()); /*stable_comparator=*/NodeComparatorName());
} }
@ -95,7 +95,7 @@ void ClusterScopingPassImpl::AddScopeToAllSuccessors(Node* start) {
std::vector<Node*> starts; std::vector<Node*> starts;
starts.push_back(start); starts.push_back(start);
auto enter = [&](Node* n) { AddOrAppendScope(n, unique_suffix); }; auto enter = [&](Node* n) { AddOrAppendXlaAutoJitScope(n, unique_suffix); };
auto not_back_edge = [](const Edge& edge) -> bool { auto not_back_edge = [](const Edge& edge) -> bool {
return !edge.src()->IsNextIteration(); return !edge.src()->IsNextIteration();
}; };
@ -129,7 +129,7 @@ Status ClusterScopingPassImpl::Run() {
// Without the heuristic, they may be put into the same cluster and it // Without the heuristic, they may be put into the same cluster and it
// can introduce artificial dependencies and incur great performance loss. // can introduce artificial dependencies and incur great performance loss.
// In this example, Node_Y becomes dependent on IteratorGetNext and the // In this example, Node_Y becomes dependent on IteratorGetNext and the
// latencies add up. // latencies add up if Node_X and Node_Y are in the same cluster.
// //
// IteratorGetNext -> Node_X -> Stage // IteratorGetNext -> Node_X -> Stage
// //

View File

@ -20,14 +20,14 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// This pass adds xla scopes to graphs to guide the later clustering passes. // This pass adds scopes to nodes in the _XlaAutoJitScope attribute to guide
// A major reason to do this is to prevent the clustering from losing // the later clustering passes. A major reason to do this is to prevent the
// the important parallelism in the Tensorflow graph, which can incur // clustering from losing critical parallelism in the Tensorflow graph, which
// great performance degradation. // can incur great performance degradation.
// //
// This pass must be run before MarkForCompilationPass, as it stores the // This pass must be run before MarkForCompilationPass, as it stores the
// scoping information in the XlaScope attributes, which MarkForCompilationPass // scoping information that MarkForCompilationPass will need to respect for
// will need to respect for clustering decision. // clustering decision.
class ClusterScopingPass : public GraphOptimizationPass { class ClusterScopingPass : public GraphOptimizationPass {
public: public:
Status Run(const GraphOptimizationPassOptions& options) override; Status Run(const GraphOptimizationPassOptions& options) override;

View File

@ -49,7 +49,7 @@ Status ClusterScoping(std::unique_ptr<Graph>* graph) {
return pass.Run(opt_options); return pass.Run(opt_options);
} }
absl::flat_hash_map<string, string> GetXlaScopes(const Graph& graph) { absl::flat_hash_map<string, string> GetXlaAutoJitScopes(const Graph& graph) {
absl::flat_hash_map<string, string> scopes; absl::flat_hash_map<string, string> scopes;
for (Node* node : graph.nodes()) { for (Node* node : graph.nodes()) {
string scope; string scope;
@ -70,8 +70,9 @@ absl::flat_hash_map<string, string> GetXlaScopes(const Graph& graph) {
Node* BuildStageNode(GraphDefBuilder& builder, string name, Node* BuildStageNode(GraphDefBuilder& builder, string name,
std::initializer_list<DataType> dtypes, std::initializer_list<DataType> dtypes,
gtl::ArraySlice<ops::NodeOut> values) { gtl::ArraySlice<ops::NodeOut> values) {
auto opts = auto opts = builder.opts()
builder.opts().WithName(std::move(name)).WithAttr("dtypes", dtypes); .WithName(std::move(name))
.WithAttr("dtypes", std::move(dtypes));
if (opts.HaveError()) { if (opts.HaveError()) {
return nullptr; return nullptr;
} }
@ -119,7 +120,7 @@ TEST(XlaCompilationTest, StagePipelinePreserved) {
TF_ASSERT_OK(ClusterScoping(&graph)); TF_ASSERT_OK(ClusterScoping(&graph));
auto scopes = GetXlaScopes(*graph); auto scopes = GetXlaAutoJitScopes(*graph);
EXPECT_NE(scopes["add0"], scopes["add1"]); EXPECT_NE(scopes["add0"], scopes["add1"]);
EXPECT_EQ(scopes["add0"], scopes["relu0"]); EXPECT_EQ(scopes["add0"], scopes["relu0"]);
EXPECT_EQ(scopes["add1"], scopes["relu1"]); EXPECT_EQ(scopes["add1"], scopes["relu1"]);
@ -171,7 +172,7 @@ TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
TF_ASSERT_OK(ClusterScoping(&graph)); TF_ASSERT_OK(ClusterScoping(&graph));
auto scopes = GetXlaScopes(*graph); auto scopes = GetXlaAutoJitScopes(*graph);
EXPECT_NE(scopes["add0"], scopes["add1"]); EXPECT_NE(scopes["add0"], scopes["add1"]);
EXPECT_NE(scopes["add0"], scopes["relu0"]); EXPECT_NE(scopes["add0"], scopes["relu0"]);
EXPECT_NE(scopes["add1"], scopes["relu1"]); EXPECT_NE(scopes["add1"], scopes["relu1"]);

View File

@ -19,11 +19,11 @@ namespace tensorflow {
const char* const kXlaCompileAttr = "_XlaCompile"; const char* const kXlaCompileAttr = "_XlaCompile";
// User-provided through jit_scope. Effective when auto_jit is OFF. // User-provided through jit_scope. Effective only when auto_jit is OFF.
const char* const kXlaScopeAttr = "_XlaScope"; const char* const kXlaScopeAttr = "_XlaScope";
// Automatically inserted by auto_jit to guide clustering results. Effective // Automatically inserted by auto_jit to guide clustering results. Effective
// when auto_jit is ON. // only when auto_jit is ON.
const char* const kXlaAutoJitScopeAttr = "_XlaAutoJitScope"; const char* const kXlaAutoJitScopeAttr = "_XlaAutoJitScope";
} // namespace tensorflow } // namespace tensorflow

View File

@ -932,24 +932,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
// The difference between _XlaScope and _XlaAutoJitScope is that _XlaScope is // The difference between _XlaScope and _XlaAutoJitScope is that _XlaScope is
// provided by users through jit_scope APIs, while _XlaAutoJitScope is // provided by users through jit_scope APIs, while _XlaAutoJitScope is
// automatically generated by the ClusterScopingPass when auto_jit is on. As // automatically generated by the ClusterScopingPass when auto_jit is on. As
// such, we want to respect _kXlaScope when auto_jit is off, while respecting // such, we respect _kXlaScope only when auto_jit is off, while respecting
// _kXlaAutoJitScope when auto_jit is on. // _kXlaAutoJitScope only when auto_jit is on.
// //
// We may want to restrict the _XlaScope behavior to require all nodes marked // We may want to restrict the _XlaScope behavior to require all nodes marked
// with _XlaCompile=true to also have a _XlaScope property set (and raise an // with _XlaCompile=true to also have a _XlaScope property set (and raise an
// error otherwise); but for now we don't do this. // error otherwise); but for now we don't do this.
if (global_jit_level_ != OptimizerOptions::OFF) { if (global_jit_level_ != OptimizerOptions::OFF) {
// If global_jit_level_ is ON, respect kXlaAutoJitScope (and ignore // If global_jit_level_ is ON, respect only kXlaAutoJitScope.
// kXlaScope).
const string& scope = const string& scope =
GetNodeAttrString(node->attrs(), kXlaAutoJitScopeAttr); GetNodeAttrString(node->attrs(), kXlaAutoJitScopeAttr);
if (!scope.empty()) { if (!scope.empty()) {
return scope; return scope;
} }
} else { } else {
// If global_jit_level_ is OFF, respect kXlaScope (and ignore // If global_jit_level_ is OFF, respect only kXlaScope.
// kXlaAutoJitScope).
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
if (!scope.empty()) { if (!scope.empty()) {
return scope; return scope;