diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4142de56813..86b98505ab0 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -707,12 +707,14 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { return Status::OK(); } +std::atomic cluster_sequence_num; + +int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; } + Status MarkForCompilationPassImpl::CreateClusters() { TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); clusters_created_ = true; - static std::atomic cluster_sequence_num; - // Names for each cluster. std::unordered_map cluster_names; @@ -745,7 +747,7 @@ Status MarkForCompilationPassImpl::CreateClusters() { string& name = cluster_names[cluster->cycles_graph_node_id()]; if (name.empty()) { - name = absl::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", GetNextClusterSequenceNumber()); } n->AddAttr(kXlaClusterAttr, name); @@ -1522,4 +1524,8 @@ Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } + +namespace testing { +void ResetClusterSequenceNumber() { cluster_sequence_num = 0; } +} // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 16b8427b60e..2eee144e645 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass { // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); + +namespace testing { +// DO NOT USE IN PRODUCTION. +// +// Resets some internal state to let us write reliable unit tests. +void ResetClusterSequenceNumber(); +} // namespace testing } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_