From 18f1467496b4529a0a60ff3f67f8e57e0d103d1f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 25 Jul 2019 09:00:53 -0700 Subject: [PATCH] [XLA] Make HLO snapshot dumping work on the LocalClient::RunAsync path. PiperOrigin-RevId: 259956061 --- .../compiler/xla/client/local_client.cc | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 427bdf878f0..e8a316882db 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -189,8 +189,49 @@ StatusOr LocalExecutable::RunAsync( ExecutableRunOptions run_options) { TF_ASSIGN_OR_RETURN(auto options_and_stream, RunHelper(arguments, run_options)); - return executable_->ExecuteAsyncOnStream(&options_and_stream.first, - arguments); + se::Stream* stream = run_options.stream(); + + std::shared_ptr snapshot; + if (executable_->dumping_snapshot()) { + snapshot = std::make_shared(); + snapshot->set_execution_platform(backend_->platform()->Name()); + *snapshot->mutable_hlo() = *executable_->hlo_proto(); + for (const ShapedBuffer* arg : arguments) { + auto literal = std::make_shared(arg->on_host_shape()); + backend_->transfer_manager()->TransferLiteralFromDevice( + stream, *arg, literal.get(), [snapshot, literal](Status status) { + if (!status.ok()) { + LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs " + "failed: " + << status; + return; + } + *snapshot->add_arguments() = literal->ToProto(); + }); + } + } + + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer outputs, + executable_->ExecuteAsyncOnStream(&options_and_stream.first, arguments)); + + // Transfer the outputs and save the snapshot to disk. + if (snapshot) { + auto literal = std::make_shared(outputs.on_host_shape()); + backend_->transfer_manager()->TransferLiteralFromDevice( + stream, outputs, literal.get(), [snapshot, literal](Status status) { + if (status.ok()) { + *snapshot->mutable_result() = literal->ToProto(); + } else { + LOG(ERROR) + << "TransferLiteralFromDevice for HLO snapshot outputs failed: " + << status; + } + DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags()); + }); + } + + return std::move(outputs); } StatusOr LocalExecutable::ExecuteAndDump(