[XLA] Make replay_computation not crash when compile fails.
PiperOrigin-RevId: 231840092
Parent commit: 1b1f0959cb
Commit: 983a547e85
@@ -90,8 +90,8 @@ struct Options {
|
||||
int num_runs = 1;
|
||||
};
|
||||
|
||||
std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
|
||||
LocalClient* client) {
|
||||
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
|
||||
const HloSnapshot& module, LocalClient* client) {
|
||||
XlaComputation computation(module.hlo().hlo_module());
|
||||
std::vector<Shape> argument_layouts;
|
||||
argument_layouts.reserve(
|
||||
@@ -102,9 +102,8 @@ std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
|
||||
argument_layouts.push_back(Shape(param));
|
||||
argument_layout_ptrs.push_back(&argument_layouts.back());
|
||||
}
|
||||
return client
|
||||
->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
|
||||
.ValueOrDie();
|
||||
return client->Compile(computation, argument_layout_ptrs,
|
||||
ExecutableBuildOptions());
|
||||
}
|
||||
|
||||
absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
||||
@@ -357,7 +356,7 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
|
||||
|
||||
// Compile all the modules in parallel.
|
||||
LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
|
||||
std::vector<std::unique_ptr<LocalExecutable>> executables;
|
||||
std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
|
||||
{
|
||||
// ThreadPool CHECK-fails if we give it 0 threads.
|
||||
tensorflow::thread::ThreadPool thread_pool(
|
||||
@@ -374,7 +373,12 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
|
||||
LOG(INFO) << "Done compiling; now running the modules.";
|
||||
|
||||
for (int64 i = 0; i < executables.size(); ++i) {
|
||||
LocalExecutable* executable = executables[i].get();
|
||||
if (!executables[i].ok()) {
|
||||
LOG(ERROR) << "Compilation failed: " << executables[i].status();
|
||||
exit_status = EXIT_FAILURE;
|
||||
continue;
|
||||
}
|
||||
LocalExecutable* executable = executables[i].ValueOrDie().get();
|
||||
LOG(ERROR) << "Running iteration " << i;
|
||||
StatusOr<Literal> result_status =
|
||||
ReplayComputation(snapshots[i], executable, client, opts);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user