[XLA] Small ergonomic improvements to replay_computation.

* Use QCHECK/LOG(QFATAL) when a back trace is not helpful to the user.
* Don't look at the maximum parallelism until after InitMain() is called; otherwise it logs an error.

PiperOrigin-RevId: 357741924
Change-Id: I7565e4e312aa915cbd55030b8bf192ef8aecca5f
This commit is contained in:
Peter Hawkins 2021-02-16 09:55:07 -08:00 committed by TensorFlower Gardener
parent b2e33765a1
commit b8b03dc654

View File

@ -86,7 +86,7 @@ namespace {
// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
Options() {}
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
@ -106,7 +106,7 @@ struct Options {
bool print_result = true;
int num_runs = 1;
int intra_op_thread_pool_size;
int intra_op_thread_pool_size = -1;
bool compile_only = false;
};
@ -173,7 +173,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
if (!fake_xfeed_shape.empty()) {
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
} else if (generate_fake_xfeed) {
CHECK_LT(xfeed_instrs.size(), 2)
QCHECK_LT(xfeed_instrs.size(), 2)
<< "--generate_fake_" << xfeed_name
<< " only works if the model has 0 or 1 " << xfeed_name << " ops.";
if (xfeed_instrs.empty()) {
@ -196,7 +196,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
<< " ops, but this model has " << xfeed_instrs.size()
<< " of them:";
log_xfeed_instrs();
LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
}
} else if (!xfeed_instrs.empty()) {
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
@ -314,8 +314,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
if (xla_hlo_profile && is_final_result) {
LOG(INFO) << "\n\n***** Final run below ******";
}
int thread_pool_size = opts.intra_op_thread_pool_size < 0
? tensorflow::port::MaxParallelism()
: opts.intra_op_thread_pool_size;
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
opts.intra_op_thread_pool_size);
thread_pool_size);
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
pool.NumThreads());
@ -366,10 +369,10 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
LOG(ERROR) << "Encountered bad proto";
}
}
CHECK(!snapshots.empty())
QCHECK(!snapshots.empty())
<< "No proto is successfully parsed from the file - the file possibly "
"has a mismatched compression option, format, etc.";
CHECK(!opts.NeedsRealData())
QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
return snapshots;
@ -387,7 +390,7 @@ StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
if (s.code() == tensorflow::error::NOT_FOUND) {
return s;
}
CHECK(!opts.NeedsRealData())
QCHECK(!opts.NeedsRealData())
<< "Without --use_fake_data or --compile_only, you must pass an "
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",