Merge pull request #39383 from tensorflow/mm-cherrypick-80851c0ad16ee3f4e4daa9c7d8f1fa97af9aeca8-on-r1.15

Manual cherrypick of #38292
Mihai Maruseac 2020-05-10 19:54:36 +00:00 committed by GitHub
commit ef48214f85
10 changed files with 84 additions and 13 deletions

@@ -20,6 +20,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",

@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -383,6 +384,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  xla::ThenExecuteFunction then_execute;
+  if (ctx->op_device_context()) {
+    then_execute = [&](se::Stream* stream, std::function<void()> fn) {
+      Status status = ctx->op_device_context()->ThenExecute(
+          static_cast<Device*>(ctx->device()), stream, std::move(fn));
+      if (!status.ok()) {
+        // This should never happen.
+        LOG(ERROR) << "ThenExecute failed " << status;
+      }
+    };
+    run_options.set_then_execute_function(&then_execute);
+  }

   Env* env = Env::Default();
   auto start_time = env->NowMicros();
@@ -579,6 +592,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  xla::ThenExecuteFunction then_execute;
+  if (ctx->op_device_context()) {
+    then_execute = [&](se::Stream* stream, std::function<void()> fn) {
+      Status status = ctx->op_device_context()->ThenExecute(
+          static_cast<Device*>(ctx->device()), stream, std::move(fn));
+      if (!status.ok()) {
+        // This should never happen.
+        LOG(ERROR) << "ThenExecute failed " << status;
+      }
+    };
+    run_options.set_then_execute_function(&then_execute);
+  }

   Env* env = Env::Default();
   auto start_time = env->NowMicros();

@@ -64,6 +64,15 @@ class RunId {
   int64 data_;
 };

+// Callback used by the GPU backend only. This is an "one-sided" version of
+// ThenDoHostCallback that enqueues a callback onto a stream. The difference
+// with ThenDoHostCallback is that the device does not block waiting for the
+// callback to complete; instead the callback is scheduled by the runtime.
+// This functionality must be provided by the caller, and hence is provided in
+// callback form.
+using ThenExecuteFunction =
+    std::function<void(stream_executor::Stream*, std::function<void()>)>;
+
 // Class containing options for running a LocalExecutable.
 class ExecutableRunOptions {
  public:
@@ -119,6 +128,15 @@ class ExecutableRunOptions {
   ExecutableRunOptions& set_run_id(RunId id);
   RunId run_id() const;

+  // See documentation on ThenExecuteFunction.
+  ExecutableRunOptions& set_then_execute_function(ThenExecuteFunction* f) {
+    then_execute_function_ = f;
+    return *this;
+  }
+  ThenExecuteFunction* then_execute_function() const {
+    return then_execute_function_;
+  }
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -128,6 +146,7 @@ class ExecutableRunOptions {
   ExecutionProfile* execution_profile_ = nullptr;
   int rng_seed_ = 0;
   stream_executor::Stream* host_to_device_stream_ = nullptr;
+  ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
 };
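
For context, a minimal sketch (an illustration, not part of this diff) of how an embedder could wire the new hook into ExecutableRunOptions. RunDeferred and ConfigureDeferredCallbacks are hypothetical names; the real callers in this commit are the XLA op kernels above, which use DeviceContext::ThenExecute for the non-blocking scheduling step.

#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/executable_run_options.h"

// Hypothetical, non-blocking scheduler supplied by the embedding runtime; it
// must run `fn` once `stream` has reached the point of enqueue.
void RunDeferred(stream_executor::Stream* stream, std::function<void()> fn);

void ConfigureDeferredCallbacks(xla::ExecutableRunOptions* run_options,
                                xla::ThenExecuteFunction* then_execute) {
  *then_execute = [](stream_executor::Stream* stream,
                     std::function<void()> fn) {
    RunDeferred(stream, std::move(fn));
  };
  // ExecutableRunOptions stores only a pointer, so *then_execute must outlive
  // the LocalExecutable::Run() call that consumes these options.
  run_options->set_then_execute_function(then_execute);
}

When no hook is set, the GPU executable falls back to Stream::ThenDoHostCallback, as the gpu_executable.cc hunk below shows.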

@@ -53,7 +53,7 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
   TF_RETURN_IF_ERROR(TransferBufferToDevice(
       stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
   // Ensure the buffer is transferred before we destroy element_pointers.
-  stream->ThenRunAfterNextBlockHostUntilDone([element_pointers]() {
+  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
     /* holds reference to element_pointers in closure */
   });
   return Status::OK();

@@ -68,7 +68,8 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
   ptrs[1] = scratch.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream,
+                params.deferred_host_callbacks);

   if (!params.stream->ok()) {
     return InternalError("ConvolutionThunk::ExecuteOnStream failed.");

@@ -221,7 +221,8 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
   ptrs[2] = output_inv_stddev.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
+                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationTraining call failed.");
   }
@@ -302,7 +303,8 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
   ptrs[2] = output_grad_offset.opaque();
   se::DeviceMemory<void*> tuple_addr(
       buffer_allocations.GetDeviceAddress(output_tuple_));
-  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
+  SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
+                params.deferred_host_callbacks);
   if (!stream.ok()) {
     return InternalError("BatchNormalizationBackward call failed.");

@@ -163,7 +163,8 @@ Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
     }
     SafeH2DMemcpy(se::DeviceMemory<void*>(
                       params.buffer_allocations->GetDeviceAddress(slice)),
-                  std::move(tuple_ptrs), n, stream);
+                  std::move(tuple_ptrs), n, stream,
+                  params.deferred_host_callbacks);
     return Status::OK();
   });
 }

@@ -160,6 +160,7 @@ Status GpuExecutable::ExecuteThunks(
   std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
   bool scoped_annotation_enabled = ScopedAnnotation::IsEnabled();
+  std::vector<std::function<void()>> deferred_host_callbacks;
   for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
     // Annotate execution of this op if tracing was enabled when we started
     // running this module. If tracing is enabled *while* we're running the
@@ -184,8 +185,12 @@ Status GpuExecutable::ExecuteThunks(
             << thunk->hlo_instruction()->ToString() << " on stream "
             << stream_no;
     Thunk::ExecuteParams thunk_params{
-        &buffer_allocations, stream, run_options->run_options().run_id(),
-        &profiler, run_options->run_options().device_assignment()};
+        &buffer_allocations,
+        stream,
+        run_options->run_options().run_id(),
+        &profiler,
+        run_options->run_options().device_assignment(),
+        &deferred_host_callbacks};
     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
     if (thunk_schedule_->Depended(thunk)) {
       auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
@@ -196,6 +201,19 @@ Status GpuExecutable::ExecuteThunks(
   }
   main_stream->ThenWaitFor(&sub_streams);
+  if (!deferred_host_callbacks.empty()) {
+    auto fn = [deferred_host_callbacks{std::move(deferred_host_callbacks)}]() {
+      for (auto& callback : deferred_host_callbacks) {
+        callback();
+      }
+    };
+    if (run_options->run_options().then_execute_function()) {
+      (*run_options->run_options().then_execute_function())(main_stream,
+                                                            std::move(fn));
+    } else {
+      main_stream->ThenDoHostCallback(std::move(fn));
+    }
+  }

   // Make sure kernels are completed before deallocating temporary buffers or
   // the profiler state.
   // TODO(b/30100571): we could potentially postpone deallocating the temp

@@ -95,8 +95,9 @@ class Thunk {
     const BufferAllocations* buffer_allocations;  // never null
     se::Stream* stream;
     RunId run_id;
-    HloExecutionProfiler* profiler;       // never null
-    const DeviceAssignment* device_assn;  // never null
+    HloExecutionProfiler* profiler;                               // never null
+    const DeviceAssignment* device_assn;                          // never null
+    std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
   };

   // Execute the kernel for the thunk on the given stream. This method must be
@@ -114,11 +115,13 @@ class Thunk {
   // Safely copies the given buffer to the GPU, deleting it on the host only
   // after the copy has completed.
   template <typename T>
-  void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf,
-                     int64 count, se::Stream* stream) {
+  void SafeH2DMemcpy(
+      se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64 count,
+      se::Stream* stream,
+      std::vector<std::function<void()>>* deferred_host_callbacks) {
     stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
     auto* buf_raw = buf.release();
-    stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; });
+    deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
   }

  private:
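
For context, a minimal sketch (an illustration, not part of this diff) of the deferred-callback pattern that the new SafeH2DMemcpy signature relies on. CopyAndDeferCleanup is a hypothetical helper that shows both halves of the pattern in one place; the body mirrors what SafeH2DMemcpy and GpuExecutable::ExecuteThunks do in the hunks above.

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// Hypothetical helper, for illustration only.
void CopyAndDeferCleanup(se::DeviceMemory<float> dest, se::Stream* stream) {
  std::vector<std::function<void()>> deferred_host_callbacks;

  // What SafeH2DMemcpy does per copy: enqueue the H2D transfer, then record
  // the delete[] of the host staging buffer instead of blocking the stream.
  auto buf = absl::make_unique<float[]>(16);  // contents omitted for brevity
  stream->ThenMemcpy(&dest, buf.get(), 16 * sizeof(float));
  auto* buf_raw = buf.release();
  deferred_host_callbacks.push_back([buf_raw] { delete[] buf_raw; });

  // What ExecuteThunks does once, after all thunks have run: hand the whole
  // batch to run_options.then_execute_function() when it is set, otherwise:
  stream->ThenDoHostCallback([cbs{std::move(deferred_host_callbacks)}]() {
    for (auto& cb : cbs) {
      cb();
    }
  });
}

When a then_execute_function is provided, even this final ThenDoHostCallback is avoided and the flush is scheduled by the embedder instead.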

@@ -37,7 +37,8 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
   SafeH2DMemcpy(se::DeviceMemory<void*>(
                     buffer_allocations.GetDeviceAddress(dest_buffer_)),
-                std::move(tuple_data), n, &stream);
+                std::move(tuple_data), n, &stream,
+                params.deferred_host_callbacks);
   return Status::OK();
 }