From eb2d84f6596e37817c38c84767dae95df398e067 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 16 May 2019 15:17:49 -0700 Subject: [PATCH] Print out the deadness predicate on a mismatch PiperOrigin-RevId: 248610290 --- tensorflow/compiler/jit/deadness_analysis.cc | 4 ++++ tensorflow/compiler/jit/deadness_analysis.h | 2 ++ .../compiler/jit/mark_for_compilation_pass.cc | 20 +++++++++++++++---- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0a92c06ad10..d2501b9ef1e 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -1188,4 +1188,8 @@ Status ComputePredicates(const Graph& graph, } } // namespace deadness_analysis_internal +string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const { + return static_cast(predicate.pred_)->ToString(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index 08d8ad011bc..c8527de503d 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -82,6 +82,8 @@ class DeadnessAnalysis { virtual void Print() const = 0; virtual ~DeadnessAnalysis(); + string DebugString(DeadnessPredicate predicate) const; + // Run the deadness analysis over `graph` and returns an error or a populated // instance of DeadnessAnalysis in `result`. static Status Run(const Graph& graph, diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 86b98505ab0..3d3497c5c36 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -401,6 +401,13 @@ class MarkForCompilationPassImpl { return true; } + string EdgeContractionFailureMsg(Cluster* from, Cluster* to, + absl::string_view reason) { + return absl::StrCat("Could not contract ", from->DebugString(*graph_), + " -> ", to->DebugString(*graph_), " because ", reason, + "."); + } + DebugOptions debug_options_; Graph* graph_; FunctionLibraryDefinition* flib_def_; @@ -1067,8 +1074,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( Cluster* from, Cluster* to, absl::string_view reason) { - VLOG(3) << "Could not contract " << from->DebugString(*graph_) << " -> " - << to->DebugString(*graph_) << " because " << reason << "."; + VLOG(3) << EdgeContractionFailureMsg(from, to, reason); return false; } @@ -1077,8 +1083,14 @@ StatusOr MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, DCHECK(from->deadness_predicate().has_value() == to->deadness_predicate().has_value()); if (from->deadness_predicate() != to->deadness_predicate()) { - return LogNotContractableAndReturnFalse( - from, to, "the two nodes have mismatching deadness"); + VLOG(3) << EdgeContractionFailureMsg( + from, to, + absl::StrCat( + "the two nodes have mismatching deadness: ", + deadness_analysis_->DebugString(*from->deadness_predicate()), + " and ", + deadness_analysis_->DebugString(*to->deadness_predicate()))); + return false; } TF_ASSIGN_OR_RETURN(bool devices_compatible,