[XLA] Small ergonomic improvements to replay_computation.
* Use QCHECK/LOG(QFATAL) when a back trace is not helpful to the user. * Don't look at the maximum parallelism until after InitMain() is called; otherwise it logs an error. PiperOrigin-RevId: 357741924 Change-Id: I7565e4e312aa915cbd55030b8bf192ef8aecca5f
This commit is contained in:
parent
b2e33765a1
commit
b8b03dc654
@ -86,7 +86,7 @@ namespace {
|
|||||||
// Command-line opts to this tool. See main() for descriptions of these
|
// Command-line opts to this tool. See main() for descriptions of these
|
||||||
// fields.
|
// fields.
|
||||||
struct Options {
|
struct Options {
|
||||||
Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
|
Options() {}
|
||||||
|
|
||||||
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
|
bool NeedsRealData() const { return !use_fake_data && !compile_only; }
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ struct Options {
|
|||||||
bool print_result = true;
|
bool print_result = true;
|
||||||
int num_runs = 1;
|
int num_runs = 1;
|
||||||
|
|
||||||
int intra_op_thread_pool_size;
|
int intra_op_thread_pool_size = -1;
|
||||||
|
|
||||||
bool compile_only = false;
|
bool compile_only = false;
|
||||||
};
|
};
|
||||||
@ -173,7 +173,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
|||||||
if (!fake_xfeed_shape.empty()) {
|
if (!fake_xfeed_shape.empty()) {
|
||||||
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
|
xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
|
||||||
} else if (generate_fake_xfeed) {
|
} else if (generate_fake_xfeed) {
|
||||||
CHECK_LT(xfeed_instrs.size(), 2)
|
QCHECK_LT(xfeed_instrs.size(), 2)
|
||||||
<< "--generate_fake_" << xfeed_name
|
<< "--generate_fake_" << xfeed_name
|
||||||
<< " only works if the model has 0 or 1 " << xfeed_name << " ops.";
|
<< " only works if the model has 0 or 1 " << xfeed_name << " ops.";
|
||||||
if (xfeed_instrs.empty()) {
|
if (xfeed_instrs.empty()) {
|
||||||
@ -196,7 +196,7 @@ absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
|||||||
<< " ops, but this model has " << xfeed_instrs.size()
|
<< " ops, but this model has " << xfeed_instrs.size()
|
||||||
<< " of them:";
|
<< " of them:";
|
||||||
log_xfeed_instrs();
|
log_xfeed_instrs();
|
||||||
LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
|
LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
|
||||||
}
|
}
|
||||||
} else if (!xfeed_instrs.empty()) {
|
} else if (!xfeed_instrs.empty()) {
|
||||||
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
|
LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
|
||||||
@ -314,8 +314,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
|
|||||||
if (xla_hlo_profile && is_final_result) {
|
if (xla_hlo_profile && is_final_result) {
|
||||||
LOG(INFO) << "\n\n***** Final run below ******";
|
LOG(INFO) << "\n\n***** Final run below ******";
|
||||||
}
|
}
|
||||||
|
int thread_pool_size = opts.intra_op_thread_pool_size < 0
|
||||||
|
? tensorflow::port::MaxParallelism()
|
||||||
|
: opts.intra_op_thread_pool_size;
|
||||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||||
opts.intra_op_thread_pool_size);
|
thread_pool_size);
|
||||||
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
|
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
|
||||||
pool.NumThreads());
|
pool.NumThreads());
|
||||||
|
|
||||||
@ -366,10 +369,10 @@ StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
|
|||||||
LOG(ERROR) << "Encountered bad proto";
|
LOG(ERROR) << "Encountered bad proto";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK(!snapshots.empty())
|
QCHECK(!snapshots.empty())
|
||||||
<< "No proto is successfully parsed from the file - the file possibly "
|
<< "No proto is successfully parsed from the file - the file possibly "
|
||||||
"has a mismatched compression option, format, etc.";
|
"has a mismatched compression option, format, etc.";
|
||||||
CHECK(!opts.NeedsRealData())
|
QCHECK(!opts.NeedsRealData())
|
||||||
<< "Without --use_fake_data or --compile_only, you must pass an "
|
<< "Without --use_fake_data or --compile_only, you must pass an "
|
||||||
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
||||||
return snapshots;
|
return snapshots;
|
||||||
@ -387,7 +390,7 @@ StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
|
|||||||
if (s.code() == tensorflow::error::NOT_FOUND) {
|
if (s.code() == tensorflow::error::NOT_FOUND) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
CHECK(!opts.NeedsRealData())
|
QCHECK(!opts.NeedsRealData())
|
||||||
<< "Without --use_fake_data or --compile_only, you must pass an "
|
<< "Without --use_fake_data or --compile_only, you must pass an "
|
||||||
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
"HloSnapshot -- HloProto and textual HLO don't carry real data.";
|
||||||
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
|
fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user