From b8b03dc654727b4c20ca98919a275756c4582906 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Tue, 16 Feb 2021 09:55:07 -0800
Subject: [PATCH] [XLA] Small ergonomic improvements to replay_computation.

* Use QCHECK/LOG(QFATAL) when a backtrace is not helpful to the user.
* Don't query the maximum parallelism until after InitMain() is called; querying it earlier logs an error.

PiperOrigin-RevId: 357741924
Change-Id: I7565e4e312aa915cbd55030b8bf192ef8aecca5f
---
 .../compiler/xla/tools/replay_computation.cc  | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 0bf5d0e937b..0317c41b6e7 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -86,7 +86,7 @@ namespace {
 // Command-line opts to this tool.  See main() for descriptions of these
 // fields.
 struct Options {
-  Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
+  Options() {}
 
   bool NeedsRealData() const { return !use_fake_data && !compile_only; }
 
@@ -106,7 +106,7 @@ struct Options {
   bool print_result = true;
   int num_runs = 1;
 
-  int intra_op_thread_pool_size;
+  int intra_op_thread_pool_size = -1;
 
   bool compile_only = false;
 };
@@ -173,7 +173,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
   if (!fake_xfeed_shape.empty()) {
     xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
   } else if (generate_fake_xfeed) {
-    CHECK_LT(xfeed_instrs.size(), 2)
+    QCHECK_LT(xfeed_instrs.size(), 2)
         << "--generate_fake_" << xfeed_name
         << " only works if the model has 0 or 1 " << xfeed_name << " ops.";
     if (xfeed_instrs.empty()) {
@@ -196,7 +196,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
                  << " ops, but this model has " << xfeed_instrs.size()
                  << " of them:";
       log_xfeed_instrs();
-      LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
+      LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
     }
   } else if (!xfeed_instrs.empty()) {
     LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
@@ -314,8 +314,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
     if (xla_hlo_profile && is_final_result) {
       LOG(INFO) << "\n\n***** Final run below ******";
     }
+    int thread_pool_size = opts.intra_op_thread_pool_size < 0
+                               ? tensorflow::port::MaxParallelism()
+                               : opts.intra_op_thread_pool_size;
     tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
-                                        opts.intra_op_thread_pool_size);
+                                        thread_pool_size);
     Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
                                         pool.NumThreads());
 
@@ -366,10 +369,10 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
       LOG(ERROR) << "Encountered bad proto";
     }
   }
-  CHECK(!snapshots.empty())
+  QCHECK(!snapshots.empty())
       << "No proto is successfully parsed from the file - the file possibly "
          "has a mismatched compression option, format, etc.";
-  CHECK(!opts.NeedsRealData())
+  QCHECK(!opts.NeedsRealData())
       << "Without --use_fake_data or --compile_only, you must pass an "
          "HloSnapshot -- HloProto and textual HLO don't carry real data.";
   return snapshots;
@@ -387,7 +390,7 @@ StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
   if (s.code() == tensorflow::error::NOT_FOUND) {
     return s;
   }
-  CHECK(!opts.NeedsRealData())
+  QCHECK(!opts.NeedsRealData())
       << "Without --use_fake_data or --compile_only, you must pass an "
          "HloSnapshot -- HloProto and textual HLO don't carry real data.";
   fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",