From 983a547e85b5cd6e6abf62bacc0e2370d474577b Mon Sep 17 00:00:00 2001
From: Michael Kuperstein
Date: Thu, 31 Jan 2019 13:15:04 -0800
Subject: [PATCH] [XLA] Make replay_computation not crash when compile fails.

PiperOrigin-RevId: 231840092
---
 .../compiler/xla/tools/replay_computation.cc   | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index c01a47b510c..21217c23f65 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -90,8 +90,8 @@ struct Options {
   int num_runs = 1;
 };
 
-std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
-                                                   LocalClient* client) {
+StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
+    const HloSnapshot& module, LocalClient* client) {
   XlaComputation computation(module.hlo().hlo_module());
   std::vector<Shape> argument_layouts;
   argument_layouts.reserve(
@@ -102,9 +102,8 @@ std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
     argument_layouts.push_back(Shape(param));
     argument_layout_ptrs.push_back(&argument_layouts.back());
   }
-  return client
-      ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
-      .ValueOrDie();
+  return client->Compile(computation, argument_layout_ptrs,
+                         ExecutableBuildOptions());
 }
 
 absl::optional<Shape> GetXfeedShape(bool is_infeed,
@@ -357,7 +356,7 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
 
   // Compile all the modules in parallel.
   LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
-  std::vector<std::unique_ptr<LocalExecutable>> executables;
+  std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
   {
     // ThreadPool CHECK-fails if we give it 0 threads.
     tensorflow::thread::ThreadPool thread_pool(
@@ -374,7 +373,12 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
 
   LOG(INFO) << "Done compiling; now running the modules.";
   for (int64 i = 0; i < executables.size(); ++i) {
-    LocalExecutable* executable = executables[i].get();
+    if (!executables[i].ok()) {
+      LOG(ERROR) << "Compilation failed: " << executables[i].status();
+      exit_status = EXIT_FAILURE;
+      continue;
+    }
+    LocalExecutable* executable = executables[i].ValueOrDie().get();
     LOG(ERROR) << "Running iteration " << i;
     StatusOr<Literal> result_status =
         ReplayComputation(snapshots[i], executable, client, opts);