diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 3fbd977cadb..af12468446c 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -20,6 +20,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:client_library",
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 8594b1ec39d..0b5bdffb259 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -383,6 +384,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  xla::ThenExecuteFunction then_execute;
+  if (ctx->op_device_context()) {
+    then_execute = [&](se::Stream* stream, std::function<void()> fn) {
+      Status status = ctx->op_device_context()->ThenExecute(
+          static_cast<Device*>(ctx->device()), stream, std::move(fn));
+      if (!status.ok()) {
+        // This should never happen.
+        LOG(ERROR) << "ThenExecute failed " << status;
+      }
+    };
+    run_options.set_then_execute_function(&then_execute);
+  }
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

@@ -579,6 +592,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  xla::ThenExecuteFunction then_execute;
+  if (ctx->op_device_context()) {
+    then_execute = [&](se::Stream* stream, std::function<void()> fn) {
+      Status status = ctx->op_device_context()->ThenExecute(
+          static_cast<Device*>(ctx->device()), stream, std::move(fn));
+      if (!status.ok()) {
+        // This should never happen.
+        LOG(ERROR) << "ThenExecute failed " << status;
+      }
+    };
+    run_options.set_then_execute_function(&then_execute);
+  }
   Env* env = Env::Default();
   auto start_time = env->NowMicros();

diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 4de8148451b..ed67bfbeb0d 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -64,6 +64,15 @@ class RunId {
   int64 data_;
 };

+// Callback used by the GPU backend only. This is a "one-sided" version of
+// ThenDoHostCallback that enqueues a callback onto a stream. The difference
+// from ThenDoHostCallback is that the device does not block waiting for the
+// callback to complete; instead the callback is scheduled by the runtime.
+// This functionality must be provided by the caller, and hence is provided in
+// callback form.
+using ThenExecuteFunction =
+    std::function<void(stream_executor::Stream*, std::function<void()>)>;
+
 // Class containing options for running a LocalExecutable.
 class ExecutableRunOptions {
  public:
@@ -119,6 +128,15 @@ class ExecutableRunOptions {
   ExecutableRunOptions& set_run_id(RunId id);
   RunId run_id() const;

+  // See documentation on ThenExecuteFunction.
+  ExecutableRunOptions& set_then_execute_function(ThenExecuteFunction* f) {
+    then_execute_function_ = f;
+    return *this;
+  }
+  ThenExecuteFunction* then_execute_function() const {
+    return then_execute_function_;
+  }
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -128,6 +146,7 @@ class ExecutableRunOptions {
   ExecutionProfile* execution_profile_ = nullptr;
   int rng_seed_ = 0;
   stream_executor::Stream* host_to_device_stream_ = nullptr;
+  ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
 };
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index d65083d701a..5a97998f1f0 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -53,7 +53,7 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
   TF_RETURN_IF_ERROR(TransferBufferToDevice(
       stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
   // Ensure the buffer is transferred before we destroy element_pointers.
-  stream->ThenRunAfterNextBlockHostUntilDone([element_pointers]() {
+  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
     /* holds reference to element_pointers in closure */
   });
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index a8408b6cb3f..a62ac4168f8 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -68,7 +68,8 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
   ptrs[1] = scratch.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream,
+                params.deferred_host_callbacks);

   if (!params.stream->ok()) {
     return InternalError("ConvolutionThunk::ExecuteOnStream failed.");
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index b582bb51485..16ee918e549 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -221,7 +221,8 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
   ptrs[2] = output_inv_stddev.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
+                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationTraining call failed.");
   }
@@ -302,7 +303,8 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
   ptrs[2] = output_grad_offset.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
+                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationBackward call failed.");
   }
diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
index 65673106391..7734801d8b7 100644
--- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
@@ -163,7 +163,8 @@ Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
     }
     SafeH2DMemcpy(se::DeviceMemory<void*>(
                       params.buffer_allocations->GetDeviceAddress(slice)),
-                  std::move(tuple_ptrs), n, stream);
+                  std::move(tuple_ptrs), n, stream,
+                  params.deferred_host_callbacks);
     return Status::OK();
   });
 }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index abf2cd1f23f..c43821438eb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -160,6 +160,7 @@ Status GpuExecutable::ExecuteThunks(
   std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
   bool scoped_annotation_enabled = ScopedAnnotation::IsEnabled();

+  std::vector<std::function<void()>> deferred_host_callbacks;
   for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
     // Annotate execution of this op if tracing was enabled when we started
     // running this module. If tracing is enabled *while* we're running the
@@ -184,8 +185,12 @@ Status GpuExecutable::ExecuteThunks(
             << thunk->hlo_instruction()->ToString() << " on stream "
             << stream_no;
     Thunk::ExecuteParams thunk_params{
-        &buffer_allocations, stream, run_options->run_options().run_id(),
-        &profiler, run_options->run_options().device_assignment()};
+        &buffer_allocations,
+        stream,
+        run_options->run_options().run_id(),
+        &profiler,
+        run_options->run_options().device_assignment(),
+        &deferred_host_callbacks};
     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
     if (thunk_schedule_->Depended(thunk)) {
       auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
@@ -196,6 +201,19 @@ Status GpuExecutable::ExecuteThunks(
   }

   main_stream->ThenWaitFor(&sub_streams);
+  if (!deferred_host_callbacks.empty()) {
+    auto fn = [deferred_host_callbacks{std::move(deferred_host_callbacks)}]() {
+      for (auto& callback : deferred_host_callbacks) {
+        callback();
+      }
+    };
+    if (run_options->run_options().then_execute_function()) {
+      (*run_options->run_options().then_execute_function())(main_stream,
+                                                            std::move(fn));
+    } else {
+      main_stream->ThenDoHostCallback(std::move(fn));
+    }
+  }
   // Make sure kernels are completed before deallocating temporary buffers or
   // the profiler state.
   // TODO(b/30100571): we could potentially postpone deallocating the temp
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 20461248539..abf829cee00 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -95,8 +95,9 @@ class Thunk {
     const BufferAllocations* buffer_allocations;  // never null
     se::Stream* stream;
     RunId run_id;
-    HloExecutionProfiler* profiler;       // never null
-    const DeviceAssignment* device_assn;  // never null
+    HloExecutionProfiler* profiler;                                // never null
+    const DeviceAssignment* device_assn;                           // never null
+    std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
   };

   // Execute the kernel for the thunk on the given stream. This method must be
@@ -114,11 +115,13 @@ class Thunk {
   // Safely copies the given buffer to the GPU, deleting it on the host only
   // after the copy has completed.
   template <typename T>
-  void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf,
-                     int64 count, se::Stream* stream) {
+  void SafeH2DMemcpy(
+      se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64 count,
+      se::Stream* stream,
+      std::vector<std::function<void()>>* deferred_host_callbacks) {
     stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
     auto* buf_raw = buf.release();
-    stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; });
+    deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
   }

  private:
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 495ccbe894d..cbbbb7baf68 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -37,7 +37,8 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
   SafeH2DMemcpy(se::DeviceMemory<void*>(
                     buffer_allocations.GetDeviceAddress(dest_buffer_)),
-                std::move(tuple_data), n, &stream);
+                std::move(tuple_data), n, &stream,
+                params.deferred_host_callbacks);
   return Status::OK();
 }
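
Note (not part of the patch): a minimal caller-side sketch of how the new hook can be wired up. Only xla::ThenExecuteFunction and ExecutableRunOptions::set_then_execute_function() come from this change; the thread pool, the helper name, and the wait-then-run scheduling policy below are assumptions for illustration.

// Sketch only: hand XLA's deferred host callbacks to a thread pool instead of
// letting GpuExecutable fall back to main_stream->ThenDoHostCallback(). The
// pool and the BlockHostUntilDone-based policy are illustrative assumptions.
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// `then_execute` must outlive the execution; ExecutableRunOptions stores only
// a raw pointer to it.
xla::ExecutableRunOptions MakeRunOptionsWithDeferredCallbacks(
    se::Stream* stream, tensorflow::thread::ThreadPool* pool,
    xla::ThenExecuteFunction* then_execute) {
  *then_execute = [pool](se::Stream* s, std::function<void()> fn) {
    // The device does not wait for the callback; the runtime (here a thread
    // pool) runs it once the work already enqueued on the stream has finished.
    pool->Schedule([s, fn = std::move(fn)]() {
      auto status = s->BlockHostUntilDone();
      if (!status.ok()) {
        LOG(ERROR) << "Stream error before deferred callback: " << status;
      }
      fn();
    });
  };

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_then_execute_function(then_execute);
  return run_options;
}

With such options set, GpuExecutable::ExecuteThunks collapses the per-thunk deferred_host_callbacks into a single function and routes it through this hook, as in the gpu_executable.cc hunk above; when the hook is absent it falls back to main_stream->ThenDoHostCallback().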