[XLA] Small ergonomic improvements to replay_computation.
* Use QCHECK/LOG(QFATAL) when a back trace is not helpful to the user. * Don't look at the maximum parallelism until after InitMain() is called; otherwise it logs an error. PiperOrigin-RevId: 357741924 Change-Id: I7565e4e312aa915cbd55030b8bf192ef8aecca5f
This commit is contained in:
parent
b2e33765a1
commit
b8b03dc654
@ -86,7 +86,7 @@ namespace {
|
||||
// Command-line opts to this tool. See main() for descriptions of these
|
||||
// fields.
|
||||
struct Options {
|
||||
Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
|
||||
Options() {}
|
||||
|
||||
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
|
||||
|
||||
@ -106,7 +106,7 @@ struct Options {
|
||||
bool print_result = true;
|
||||
int num_runs = 1;
|
||||
|
||||
int intra_op_thread_pool_size;
|
||||
int intra_op_thread_pool_size = -1;
|
||||
|
||||
bool compile_only = false;
|
||||
};
|
||||
@ -173,7 +173,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
||||
if (!fake_xfeed_shape.empty()) {
|
||||
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
|
||||
} else if (generate_fake_xfeed) {
|
||||
CHECK_LT(xfeed_instrs.size(), 2)
|
||||
QCHECK_LT(xfeed_instrs.size(), 2)
|
||||
<< "--generate_fake_" << xfeed_name
|
||||
<< " only works if the model has 0 or 1 " << xfeed_name << " ops.";
|
||||
if (xfeed_instrs.empty()) {
|
||||
@ -196,7 +196,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
||||
<< " ops, but this model has " << xfeed_instrs.size()
|
||||
<< " of them:";
|
||||
log_xfeed_instrs();
|
||||
LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
|
||||
LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
|
||||
}
|
||||
} else if (!xfeed_instrs.empty()) {
|
||||
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
|
||||
@ -314,8 +314,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
|
||||
if (xla_hlo_profile && is_final_result) {
|
||||
LOG(INFO) << "\n\n***** Final run below ******";
|
||||
}
|
||||
int thread_pool_size = opts.intra_op_thread_pool_size < 0
|
||||
? tensorflow::port::MaxParallelism()
|
||||
: opts.intra_op_thread_pool_size;
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||
opts.intra_op_thread_pool_size);
|
||||
thread_pool_size);
|
||||
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
|
||||
pool.NumThreads());
|
||||
|
||||
@ -366,10 +369,10 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
|
||||
LOG(ERROR) << "Encountered bad proto";
|
||||
}
|
||||
}
|
||||
CHECK(!snapshots.empty())
|
||||
QCHECK(!snapshots.empty())
|
||||
<< "No proto is successfully parsed from the file - the file possibly "
|
||||
"has a mismatched compression option, format, etc.";
|
||||
CHECK(!opts.NeedsRealData())
|
||||
QCHECK(!opts.NeedsRealData())
|
||||
<< "Without --use_fake_data or --compile_only, you must pass an "
|
||||
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
||||
return snapshots;
|
||||
@ -387,7 +390,7 @@ StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
|
||||
if (s.code() == tensorflow::error::NOT_FOUND) {
|
||||
return s;
|
||||
}
|
||||
CHECK(!opts.NeedsRealData())
|
||||
QCHECK(!opts.NeedsRealData())
|
||||
<< "Without --use_fake_data or --compile_only, you must pass an "
|
||||
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
||||
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
|
||||
|
Loading…
Reference in New Issue
Block a user