diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 0bf5d0e937b..0317c41b6e7 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -86,7 +86,7 @@ namespace { // Command-line opts to this tool. See main() for descriptions of these // fields. struct Options { - Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {} + Options() {} bool NeedsRealData() const { return !use_fake_data && !compile_only; } @@ -106,7 +106,7 @@ struct Options { bool print_result = true; int num_runs = 1; - int intra_op_thread_pool_size; + int intra_op_thread_pool_size = -1; bool compile_only = false; }; @@ -173,7 +173,7 @@ absl::optional GetXfeedShape(bool is_infeed, if (!fake_xfeed_shape.empty()) { xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); } else if (generate_fake_xfeed) { - CHECK_LT(xfeed_instrs.size(), 2) + QCHECK_LT(xfeed_instrs.size(), 2) << "--generate_fake_" << xfeed_name << " only works if the model has 0 or 1 " << xfeed_name << " ops."; if (xfeed_instrs.empty()) { @@ -196,7 +196,7 @@ absl::optional GetXfeedShape(bool is_infeed, << " ops, but this model has " << xfeed_instrs.size() << " of them:"; log_xfeed_instrs(); - LOG(FATAL) << "Can't run model with --generate_fake_infeed."; + LOG(QFATAL) << "Can't run model with --generate_fake_infeed."; } } else if (!xfeed_instrs.empty()) { LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name @@ -314,8 +314,11 @@ StatusOr ReplayComputation(const HloSnapshot& module, if (xla_hlo_profile && is_final_result) { LOG(INFO) << "\n\n***** Final run below ******"; } + int thread_pool_size = opts.intra_op_thread_pool_size < 0 + ? tensorflow::port::MaxParallelism() + : opts.intra_op_thread_pool_size; tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", - opts.intra_op_thread_pool_size); + thread_pool_size); Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(), pool.NumThreads()); @@ -366,10 +369,10 @@ StatusOr> ParseRecordIoFile(absl::string_view filename, LOG(ERROR) << "Encountered bad proto"; } } - CHECK(!snapshots.empty()) + QCHECK(!snapshots.empty()) << "No proto is successfully parsed from the file - the file possibly " "has a mismatched compression option, format, etc."; - CHECK(!opts.NeedsRealData()) + QCHECK(!opts.NeedsRealData()) << "Without --use_fake_data or --compile_only, you must pass an " "HloSnapshot -- HloProto and textual HLO don't carry real data."; return snapshots; @@ -387,7 +390,7 @@ StatusOr ParseSingleHloFile(const string& filename, if (s.code() == tensorflow::error::NOT_FOUND) { return s; } - CHECK(!opts.NeedsRealData()) + QCHECK(!opts.NeedsRealData()) << "Without --use_fake_data or --compile_only, you must pass an " "HloSnapshot -- HloProto and textual HLO don't carry real data."; fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",