[XLA:GPU] Don't defer host callbacks until BlockHostUntilDone().

Some clients (e.g., TF2, JAX) may never call BlockHostUntilDone(), so callbacks deferred until then may never run; in effect, the memory they were meant to release is leaked.

For simplicity, use a blocking host callback (ThenDoHostCallback) instead, but batch the deallocations into a single callback to minimize the overhead.
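
A rough, self-contained sketch of the resulting pattern (FakeStream is a stand-in for se::Stream, not the real StreamExecutor API; the real wiring lives in GpuExecutable::ExecuteThunks in this change):

    #include <cstdio>
    #include <functional>
    #include <utility>
    #include <vector>

    // Stand-in for se::Stream: runs host callbacks inline for illustration.
    struct FakeStream {
      void ThenDoHostCallback(std::function<void()> fn) { fn(); }
    };

    int main() {
      FakeStream stream;
      std::vector<std::function<void()>> deferred_host_callbacks;

      // Each thunk defers freeing its host-side staging buffer by appending a
      // callback, instead of parking it until BlockHostUntilDone().
      for (int i = 0; i < 3; ++i) {
        auto* buf_raw = new char[256];
        deferred_host_callbacks.push_back([buf_raw] { delete[] buf_raw; });
      }

      // Once all thunks are enqueued, wrap every deferred deallocation in one
      // host callback so the stream is blocked only once.
      if (!deferred_host_callbacks.empty()) {
        auto fn = [callbacks = std::move(deferred_host_callbacks)]() {
          for (const auto& callback : callbacks) callback();
        };
        stream.ThenDoHostCallback(std::move(fn));
      }
      std::printf("all deferred host buffers released\n");
      return 0;
    }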

PiperOrigin-RevId: 294696009
Change-Id: I9527d2f800550168518eca5216afbaa5c2d45672
Author: Peter Hawkins, 2020-02-12 10:17:46 -08:00 (committed by TensorFlower Gardener)
parent 438a63ff08
commit 80851c0ad1
10 changed files with 85 additions and 13 deletions

View File

@@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",

View File

@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -41,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@@ -384,6 +386,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();
@@ -592,6 +606,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();

View File

@@ -64,6 +64,15 @@ class RunId {
int64 data_;
};
// Callback used by the GPU backend only. This is a "one-sided" version of
// ThenDoHostCallback that enqueues a callback onto a stream. The difference
// with ThenDoHostCallback is that the device does not block waiting for the
// callback to complete; instead the callback is scheduled by the runtime.
// This functionality must be provided by the caller, and hence is provided in
// callback form.
using ThenExecuteFunction =
std::function<void(stream_executor::Stream*, std::function<void()>)>;
// Class containing options for running a LocalExecutable.
class ExecutableRunOptions {
public:
@@ -119,6 +128,15 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_run_id(RunId id);
RunId run_id() const;
// See documentation on ThenExecuteFunction.
ExecutableRunOptions& set_then_execute_function(ThenExecuteFunction* f) {
then_execute_function_ = f;
return *this;
}
ThenExecuteFunction* then_execute_function() const {
return then_execute_function_;
}
private:
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
@@ -128,6 +146,7 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
stream_executor::Stream* host_to_device_stream_ = nullptr;
ThenExecuteFunction* then_execute_function_ = nullptr;
RunId run_id_;
};
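
A minimal sketch of how a caller might plug in a ThenExecuteFunction through these options (the helper name SetBlockingThenExecute is hypothetical; the types, setter, and includes are the ones touched by this change). ExecutableRunOptions stores only a ThenExecuteFunction*, so the function object must outlive the execution; the TF kernels earlier in this change keep theirs on the stack for the duration of Compute(). The lambda below is the trivial, blocking implementation, equivalent to the default fallback in GpuExecutable::ExecuteThunks.

    #include <functional>
    #include <utility>

    #include "tensorflow/compiler/xla/executable_run_options.h"
    #include "tensorflow/core/platform/stream_executor_no_cuda.h"

    // Hypothetical helper: install a then-execute function that simply blocks
    // the stream via ThenDoHostCallback. `then_execute` must stay alive until
    // the launched execution has completed, because the run options only keep
    // a pointer to it.
    void SetBlockingThenExecute(xla::ExecutableRunOptions* run_options,
                                xla::ThenExecuteFunction* then_execute) {
      *then_execute = [](stream_executor::Stream* stream,
                         std::function<void()> fn) {
        stream->ThenDoHostCallback(std::move(fn));
      };
      run_options->set_then_execute_function(then_execute);
    }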

View File

@@ -52,7 +52,7 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
// Ensure the buffer is transferred before we destroy element_pointers.
stream->ThenRunAfterNextBlockHostUntilDone([element_pointers]() {
stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
/* holds reference to element_pointers in closure */
});
return Status::OK();

View File

@@ -67,7 +67,8 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
ptrs[1] = scratch.opaque();
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream);
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream,
params.deferred_host_callbacks);
if (!params.stream->ok()) {
return InternalError("ConvolutionThunk::ExecuteOnStream failed.");

View File

@@ -194,7 +194,8 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
ptrs[2] = output_inv_stddev.opaque();
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(output_tuple_));
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
params.deferred_host_callbacks);
if (!stream.ok()) {
return InternalError("BatchNormalizationTraining call failed.");
}
@@ -264,7 +265,8 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
ptrs[2] = output_grad_offset.opaque();
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(output_tuple_));
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream);
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, stream,
params.deferred_host_callbacks);
if (!stream->ok()) {
return InternalError("BatchNormalizationBackward call failed.");

View File

@@ -162,7 +162,8 @@ Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
}
SafeH2DMemcpy(se::DeviceMemory<void*>(
params.buffer_allocations->GetDeviceAddress(slice)),
std::move(tuple_ptrs), n, stream);
std::move(tuple_ptrs), n, stream,
params.deferred_host_callbacks);
return Status::OK();
});
}

View File

@@ -172,6 +172,7 @@ Status GpuExecutable::ExecuteThunks(
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
bool scoped_annotation_enabled = ScopedAnnotation::IsEnabled();
std::vector<std::function<void()>> deferred_host_callbacks;
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
// Annotate execution of this op if tracing was enabled when we started
// running this module. If tracing is enabled *while* we're running the
@@ -196,8 +197,12 @@ Status GpuExecutable::ExecuteThunks(
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
Thunk::ExecuteParams thunk_params{
&buffer_allocations, stream, run_options->run_options().run_id(),
&profiler, run_options->run_options().device_assignment()};
&buffer_allocations,
stream,
run_options->run_options().run_id(),
&profiler,
run_options->run_options().device_assignment(),
&deferred_host_callbacks};
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
if (thunk_schedule_->Depended(thunk)) {
auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
@@ -208,6 +213,19 @@ Status GpuExecutable::ExecuteThunks(
}
main_stream->ThenWaitFor(&sub_streams);
if (!deferred_host_callbacks.empty()) {
auto fn = [deferred_host_callbacks{std::move(deferred_host_callbacks)}]() {
for (auto& callback : deferred_host_callbacks) {
callback();
}
};
if (run_options->run_options().then_execute_function()) {
(*run_options->run_options().then_execute_function())(main_stream,
std::move(fn));
} else {
main_stream->ThenDoHostCallback(std::move(fn));
}
}
// Make sure kernels are completed before deallocating temporary buffers or
// the profiler state.
// TODO(b/30100571): we could potentially postpone deallocating the temp

View File

@@ -95,8 +95,9 @@ class Thunk {
const BufferAllocations* buffer_allocations; // never null
se::Stream* stream;
RunId run_id;
HloExecutionProfiler* profiler; // never null
const DeviceAssignment* device_assn; // never null
HloExecutionProfiler* profiler; // never null
const DeviceAssignment* device_assn; // never null
std::vector<std::function<void()>>* deferred_host_callbacks; // never null
};
// Execute the kernel for the thunk on the given stream. This method must be
@@ -114,11 +115,13 @@ class Thunk {
// Safely copies the given buffer to the GPU, deleting it on the host only
// after the copy has completed.
template <typename T>
void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf,
int64 count, se::Stream* stream) {
void SafeH2DMemcpy(
se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64 count,
se::Stream* stream,
std::vector<std::function<void()>>* deferred_host_callbacks) {
stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
auto* buf_raw = buf.release();
stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; });
deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
}
private:
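
For reference, a sketch of a call site using the new SafeH2DMemcpy signature (PointerTableThunk is hypothetical and includes are omitted; the real call sites in this change are ConvolutionThunk, the CudnnBatchNorm thunks, CustomCallThunk, and TupleThunk). The delete[] that SafeH2DMemcpy defers now lands in *deferred_host_callbacks, which GpuExecutable::ExecuteThunks owns and drains in a single batched host callback after all thunks have been enqueued, so the host buffer stays alive until that callback runs.

    // Hypothetical subclass for illustration only; assumes the declarations in
    // this header (Thunk, ExecuteParams, SafeH2DMemcpy), a Kind::kTuple
    // enumerator, and the usual Thunk constructor taking
    // (Kind, const HloInstruction*).
    class PointerTableThunk : public Thunk {
     public:
      PointerTableThunk(const HloInstruction* hlo, BufferAllocation::Slice dest,
                        se::DeviceMemoryBase ptr0, se::DeviceMemoryBase ptr1)
          : Thunk(Kind::kTuple, hlo), dest_(dest), ptr0_(ptr0), ptr1_(ptr1) {}

      Status ExecuteOnStream(const ExecuteParams& params) override {
        constexpr int64 kNumPointers = 2;
        auto ptrs = absl::make_unique<void*[]>(kNumPointers);
        ptrs[0] = ptr0_.opaque();
        ptrs[1] = ptr1_.opaque();
        // The delete[] of `ptrs` is appended to *params.deferred_host_callbacks
        // and runs in the single batched host callback enqueued by
        // GpuExecutable::ExecuteThunks, not at BlockHostUntilDone().
        SafeH2DMemcpy(se::DeviceMemory<void*>(
                          params.buffer_allocations->GetDeviceAddress(dest_)),
                      std::move(ptrs), kNumPointers, params.stream,
                      params.deferred_host_callbacks);
        return Status::OK();
      }

     private:
      BufferAllocation::Slice dest_;
      se::DeviceMemoryBase ptr0_, ptr1_;
    };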

View File

@@ -37,7 +37,8 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
SafeH2DMemcpy(se::DeviceMemory<void*>(
buffer_allocations.GetDeviceAddress(dest_buffer_)),
std::move(tuple_data), n, &stream);
std::move(tuple_data), n, &stream,
params.deferred_host_callbacks);
return Status::OK();
}