[XLA] Make replay_computation not crash when compile fails.

PiperOrigin-RevId: 231840092
This commit is contained in:
Michael Kuperstein 2019-01-31 13:15:04 -08:00 committed by TensorFlower Gardener
parent 1b1f0959cb
commit 983a547e85

View File

@ -90,8 +90,8 @@ struct Options {
int num_runs = 1;
};
std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
LocalClient* client) {
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
const HloSnapshot& module, LocalClient* client) {
XlaComputation computation(module.hlo().hlo_module());
std::vector<Shape> argument_layouts;
argument_layouts.reserve(
@ -102,9 +102,8 @@ std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
argument_layouts.push_back(Shape(param));
argument_layout_ptrs.push_back(&argument_layouts.back());
}
return client
->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
.ValueOrDie();
return client->Compile(computation, argument_layout_ptrs,
ExecutableBuildOptions());
}
absl::optional<Shape> GetXfeedShape(bool is_infeed,
@ -357,7 +356,7 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
// Compile all the modules in parallel.
LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
std::vector<std::unique_ptr<LocalExecutable>> executables;
std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
{
// ThreadPool CHECK-fails if we give it 0 threads.
tensorflow::thread::ThreadPool thread_pool(
@ -374,7 +373,12 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
LOG(INFO) << "Done compiling; now running the modules.";
for (int64 i = 0; i < executables.size(); ++i) {
LocalExecutable* executable = executables[i].get();
if (!executables[i].ok()) {
LOG(ERROR) << "Compilation failed: " << executables[i].status();
exit_status = EXIT_FAILURE;
continue;
}
LocalExecutable* executable = executables[i].ValueOrDie().get();
LOG(ERROR) << "Running iteration " << i;
StatusOr<Literal> result_status =
ReplayComputation(snapshots[i], executable, client, opts);