Merge pull request #39383 from tensorflow/mm-cherrypick-80851c0ad16ee3f4e4daa9c7d8f1fa97af9aeca8-on-r1.15
Manual cherrypick of #38292
This commit is contained in:
commit
ef48214f85
@ -20,6 +20,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -383,6 +384,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
static_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
@ -579,6 +592,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
xla::ThenExecuteFunction then_execute;
|
||||
if (ctx->op_device_context()) {
|
||||
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
|
||||
Status status = ctx->op_device_context()->ThenExecute(
|
||||
static_cast<Device*>(ctx->device()), stream, std::move(fn));
|
||||
if (!status.ok()) {
|
||||
// This should never happen.
|
||||
LOG(ERROR) << "ThenExecute failed " << status;
|
||||
}
|
||||
};
|
||||
run_options.set_then_execute_function(&then_execute);
|
||||
}
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
|
@ -64,6 +64,15 @@ class RunId {
|
||||
int64 data_;
|
||||
};
|
||||
|
||||
// Callback used by the GPU backend only. This is an "one-sided" version of
|
||||
// ThenDoHostCallback that enqueues a callback onto a stream. The difference
|
||||
// with ThenDoHostCallback is that the device does not block waiting for the
|
||||
// callback to complete; instead the callback is scheduled by the runtime.
|
||||
// This functionality must be provided by the caller, and hence is provided in
|
||||
// callback form.
|
||||
using ThenExecuteFunction =
|
||||
std::function<void(stream_executor::Stream*, std::function<void()>)>;
|
||||
|
||||
// Class containing options for running a LocalExecutable.
|
||||
class ExecutableRunOptions {
|
||||
public:
|
||||
@ -119,6 +128,15 @@ class ExecutableRunOptions {
|
||||
ExecutableRunOptions& set_run_id(RunId id);
|
||||
RunId run_id() const;
|
||||
|
||||
// See documentation on ThenExecuteFunction.
|
||||
ExecutableRunOptions& set_then_execute_function(ThenExecuteFunction* f) {
|
||||
then_execute_function_ = f;
|
||||
return *this;
|
||||
}
|
||||
ThenExecuteFunction* then_execute_function() const {
|
||||
return then_execute_function_;
|
||||
}
|
||||
|
||||
private:
|
||||
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
int device_ordinal_ = -1;
|
||||
@ -128,6 +146,7 @@ class ExecutableRunOptions {
|
||||
ExecutionProfile* execution_profile_ = nullptr;
|
||||
int rng_seed_ = 0;
|
||||
stream_executor::Stream* host_to_device_stream_ = nullptr;
|
||||
ThenExecuteFunction* then_execute_function_ = nullptr;
|
||||
RunId run_id_;
|
||||
};
|
||||
|
||||
|
@ -53,7 +53,7 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
|
||||
TF_RETURN_IF_ERROR(TransferBufferToDevice(
|
||||
stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
|
||||
// Ensure the buffer is transferred before we destroy element_pointers.
|
||||
stream->ThenRunAfterNextBlockHostUntilDone([element_pointers]() {
|
||||
stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
|
||||
/* holds reference to element_pointers in closure */
|
||||
});
|
||||
return Status::OK();
|
||||
|
@ -68,7 +68,8 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
ptrs[1] = scratch.opaque();
|
||||
se::DeviceMemory<void*> tuple_addr(
|
||||
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream);
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, params.stream,
|
||||
params.deferred_host_callbacks);
|
||||
|
||||
if (!params.stream->ok()) {
|
||||
return InternalError("ConvolutionThunk::ExecuteOnStream failed.");
|
||||
|
@ -221,7 +221,8 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
|
||||
ptrs[2] = output_inv_stddev.opaque();
|
||||
se::DeviceMemory<void*> tuple_addr(
|
||||
buffer_allocations.GetDeviceAddress(output_tuple_));
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
|
||||
params.deferred_host_callbacks);
|
||||
if (!stream.ok()) {
|
||||
return InternalError("BatchNormalizationTraining call failed.");
|
||||
}
|
||||
@ -302,7 +303,8 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
|
||||
ptrs[2] = output_grad_offset.opaque();
|
||||
se::DeviceMemory<void*> tuple_addr(
|
||||
buffer_allocations.GetDeviceAddress(output_tuple_));
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream);
|
||||
SafeH2DMemcpy(tuple_addr, std::move(ptrs), kNumOutputs, &stream,
|
||||
params.deferred_host_callbacks);
|
||||
|
||||
if (!stream.ok()) {
|
||||
return InternalError("BatchNormalizationBackward call failed.");
|
||||
|
@ -163,7 +163,8 @@ Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
}
|
||||
SafeH2DMemcpy(se::DeviceMemory<void*>(
|
||||
params.buffer_allocations->GetDeviceAddress(slice)),
|
||||
std::move(tuple_ptrs), n, stream);
|
||||
std::move(tuple_ptrs), n, stream,
|
||||
params.deferred_host_callbacks);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
@ -160,6 +160,7 @@ Status GpuExecutable::ExecuteThunks(
|
||||
|
||||
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
|
||||
bool scoped_annotation_enabled = ScopedAnnotation::IsEnabled();
|
||||
std::vector<std::function<void()>> deferred_host_callbacks;
|
||||
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
|
||||
// Annotate execution of this op if tracing was enabled when we started
|
||||
// running this module. If tracing is enabled *while* we're running the
|
||||
@ -184,8 +185,12 @@ Status GpuExecutable::ExecuteThunks(
|
||||
<< thunk->hlo_instruction()->ToString() << " on stream "
|
||||
<< stream_no;
|
||||
Thunk::ExecuteParams thunk_params{
|
||||
&buffer_allocations, stream, run_options->run_options().run_id(),
|
||||
&profiler, run_options->run_options().device_assignment()};
|
||||
&buffer_allocations,
|
||||
stream,
|
||||
run_options->run_options().run_id(),
|
||||
&profiler,
|
||||
run_options->run_options().device_assignment(),
|
||||
&deferred_host_callbacks};
|
||||
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
|
||||
if (thunk_schedule_->Depended(thunk)) {
|
||||
auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
|
||||
@ -196,6 +201,19 @@ Status GpuExecutable::ExecuteThunks(
|
||||
}
|
||||
|
||||
main_stream->ThenWaitFor(&sub_streams);
|
||||
if (!deferred_host_callbacks.empty()) {
|
||||
auto fn = [deferred_host_callbacks{std::move(deferred_host_callbacks)}]() {
|
||||
for (auto& callback : deferred_host_callbacks) {
|
||||
callback();
|
||||
}
|
||||
};
|
||||
if (run_options->run_options().then_execute_function()) {
|
||||
(*run_options->run_options().then_execute_function())(main_stream,
|
||||
std::move(fn));
|
||||
} else {
|
||||
main_stream->ThenDoHostCallback(std::move(fn));
|
||||
}
|
||||
}
|
||||
// Make sure kernels are completed before deallocating temporary buffers or
|
||||
// the profiler state.
|
||||
// TODO(b/30100571): we could potentially postpone deallocating the temp
|
||||
|
@ -95,8 +95,9 @@ class Thunk {
|
||||
const BufferAllocations* buffer_allocations; // never null
|
||||
se::Stream* stream;
|
||||
RunId run_id;
|
||||
HloExecutionProfiler* profiler; // never null
|
||||
const DeviceAssignment* device_assn; // never null
|
||||
HloExecutionProfiler* profiler; // never null
|
||||
const DeviceAssignment* device_assn; // never null
|
||||
std::vector<std::function<void()>>* deferred_host_callbacks; // never null
|
||||
};
|
||||
|
||||
// Execute the kernel for the thunk on the given stream. This method must be
|
||||
@ -114,11 +115,13 @@ class Thunk {
|
||||
// Safely copies the given buffer to the GPU, deleting it on the host only
|
||||
// after the copy has completed.
|
||||
template <typename T>
|
||||
void SafeH2DMemcpy(se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf,
|
||||
int64 count, se::Stream* stream) {
|
||||
void SafeH2DMemcpy(
|
||||
se::DeviceMemory<T> dest, std::unique_ptr<T[]> buf, int64 count,
|
||||
se::Stream* stream,
|
||||
std::vector<std::function<void()>>* deferred_host_callbacks) {
|
||||
stream->ThenMemcpy(&dest, buf.get(), count * sizeof(T));
|
||||
auto* buf_raw = buf.release();
|
||||
stream->ThenRunAfterNextBlockHostUntilDone([buf_raw] { delete[] buf_raw; });
|
||||
deferred_host_callbacks->push_back([buf_raw] { delete[] buf_raw; });
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -37,7 +37,8 @@ Status TupleThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
SafeH2DMemcpy(se::DeviceMemory<void*>(
|
||||
buffer_allocations.GetDeviceAddress(dest_buffer_)),
|
||||
std::move(tuple_data), n, &stream);
|
||||
std::move(tuple_data), n, &stream,
|
||||
params.deferred_host_callbacks);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user