Reset the cluster sequence number for each auto-clustering test
For various reasons we have to use a shared atomic counter to seed cluster names in mark_for_compilation_pass. Since this state is global, running tests in isolation ends up resulting in different cluster names and thus test failures. To fix this reset the cluster sequence number state before every state. While at it, also fix up some of the graphdefs that were checked in before that were already auto-clustered. Auto-clustering on these graphs is a no-op so these tests were not actually testing anything. Also detect this and fail the test if this happens in the future (this explains why only nmt_server_en_de_gpu.golden_summary needed to be updated to account for resetting the cluster sequence number). PiperOrigin-RevId: 248604614
This commit is contained in:
parent
a8c980f219
commit
1bc0a7ca33
@ -707,12 +707,14 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
|
||||
|
||||
Status MarkForCompilationPassImpl::CreateClusters() {
|
||||
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
|
||||
clusters_created_ = true;
|
||||
|
||||
static std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
// Names for each cluster.
|
||||
std::unordered_map<int, string> cluster_names;
|
||||
|
||||
@ -745,7 +747,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
|
||||
string& name = cluster_names[cluster->cycles_graph_node_id()];
|
||||
|
||||
if (name.empty()) {
|
||||
name = absl::StrCat("cluster_", cluster_sequence_num++);
|
||||
name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
|
||||
}
|
||||
|
||||
n->AddAttr(kXlaClusterAttr, name);
|
||||
@ -1522,4 +1524,8 @@ Status MarkForCompilationPass::RunForTest(
|
||||
|
||||
return MarkForCompilation(options, debug_options);
|
||||
}
|
||||
|
||||
namespace testing {
|
||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable.
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
//
|
||||
// Resets some internal state to let us write reliable unit tests.
|
||||
void ResetClusterSequenceNumber();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user