[XLA] Small ergonomic improvements to replay_computation.

* Use QCHECK/LOG(QFATAL) for user-facing failures (bad flags or input files), where a backtrace is not helpful to the user.
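* Don't query the maximum parallelism (tensorflow::port::MaxParallelism()) until after InitMain() has been called; otherwise it logs an error. The intra-op thread pool size now defaults to -1 and is resolved to MaxParallelism() when the thread pool is created.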

PiperOrigin-RevId: 357741924
Change-Id: I7565e4e312aa915cbd55030b8bf192ef8aecca5f
This commit is contained in:
Peter Hawkins 2021-02-16 09:55:07 -08:00 committed by TensorFlower Gardener
parent b2e33765a1
commit b8b03dc654

View File

@ -86,7 +86,7 @@ namespace {
// Command-line opts to this tool. See main() for descriptions of these // Command-line opts to this tool. See main() for descriptions of these
// fields. // fields.
struct Options { struct Options {
Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {} Options() {}
bool NeedsRealData() const { return !use_fake_data && !compile_only; } bool NeedsRealData() const { return !use_fake_data && !compile_only; }
@ -106,7 +106,7 @@ struct Options {
bool print_result = true; bool print_result = true;
int num_runs = 1; int num_runs = 1;
int intra_op_thread_pool_size; int intra_op_thread_pool_size = -1;
bool compile_only = false; bool compile_only = false;
}; };
@ -173,7 +173,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
if (!fake_xfeed_shape.empty()) { if (!fake_xfeed_shape.empty()) {
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
} else if (generate_fake_xfeed) { } else if (generate_fake_xfeed) {
CHECK_LT(xfeed_instrs.size(), 2) QCHECK_LT(xfeed_instrs.size(), 2)
<< "--generate_fake_" << xfeed_name << "--generate_fake_" << xfeed_name
<< " only works if the model has 0 or 1 " << xfeed_name << " ops."; << " only works if the model has 0 or 1 " << xfeed_name << " ops.";
if (xfeed_instrs.empty()) { if (xfeed_instrs.empty()) {
@ -196,7 +196,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
<< " ops, but this model has " << xfeed_instrs.size() << " ops, but this model has " << xfeed_instrs.size()
<< " of them:"; << " of them:";
log_xfeed_instrs(); log_xfeed_instrs();
LOG(FATAL) << "Can't run model with --generate_fake_infeed."; LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
} }
} else if (!xfeed_instrs.empty()) { } else if (!xfeed_instrs.empty()) {
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
@ -314,8 +314,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
if (xla_hlo_profile && is_final_result) { if (xla_hlo_profile && is_final_result) {
LOG(INFO) << "\n\n***** Final run below ******"; LOG(INFO) << "\n\n***** Final run below ******";
} }
int thread_pool_size = opts.intra_op_thread_pool_size < 0
? tensorflow::port::MaxParallelism()
: opts.intra_op_thread_pool_size;
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
opts.intra_op_thread_pool_size); thread_pool_size);
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(), Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
pool.NumThreads()); pool.NumThreads());
@ -366,10 +369,10 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
LOG(ERROR) << "Encountered bad proto"; LOG(ERROR) << "Encountered bad proto";
} }
} }
CHECK(!snapshots.empty()) QCHECK(!snapshots.empty())
<< "No proto is successfully parsed from the file - the file possibly " << "No proto is successfully parsed from the file - the file possibly "
"has a mismatched compression option, format, etc."; "has a mismatched compression option, format, etc.";
CHECK(!opts.NeedsRealData()) QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an " << "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data."; "HloSnapshot -- HloProto and textual HLO don't carry real data.";
return snapshots; return snapshots;
@ -387,7 +390,7 @@ StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
if (s.code() == tensorflow::error::NOT_FOUND) { if (s.code() == tensorflow::error::NOT_FOUND) {
return s; return s;
} }
CHECK(!opts.NeedsRealData()) QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an " << "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data."; "HloSnapshot -- HloProto and textual HLO don't carry real data.";
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",