Reset the cluster sequence number for each auto-clustering test

For various reasons we have to use a shared atomic counter to seed cluster names
in mark_for_compilation_pass.  Since this state is global, running tests in
isolation ends up resulting in different cluster names and thus test failures.

To fix this reset the cluster sequence number state before every state.

While at it, also fix up some of the graphdefs that were checked in before that
were already auto-clustered.  Auto-clustering on these graphs is a no-op so these
tests were not actually testing anything.

Also detect this and fail the test if this happens in the future (this explains why
only nmt_server_en_de_gpu.golden_summary needed to be updated to account for
resetting the cluster sequence number).

PiperOrigin-RevId: 248604614
This commit is contained in:
Sanjoy Das 2019-05-16 14:49:20 -07:00 committed by TensorFlower Gardener
parent a8c980f219
commit 1bc0a7ca33
2 changed files with 16 additions and 3 deletions

View File

@ -707,12 +707,14 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
return Status::OK();
}
std::atomic<int64> cluster_sequence_num;
int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
Status MarkForCompilationPassImpl::CreateClusters() {
TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
clusters_created_ = true;
static std::atomic<int64> cluster_sequence_num;
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
@ -745,7 +747,7 @@ Status MarkForCompilationPassImpl::CreateClusters() {
string& name = cluster_names[cluster->cycles_graph_node_id()];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
}
n->AddAttr(kXlaClusterAttr, name);
@ -1522,4 +1524,8 @@ Status MarkForCompilationPass::RunForTest(
return MarkForCompilation(options, debug_options);
}
namespace testing {
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
} // namespace testing
} // namespace tensorflow

View File

@ -51,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass {
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
namespace testing {
// DO NOT USE IN PRODUCTION.
//
// Resets some internal state to let us write reliable unit tests.
void ResetClusterSequenceNumber();
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_