diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index f4c33541b8d..3ec3cda9d82 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -207,6 +207,13 @@ BaseGPUDevice::~BaseGPUDevice() {
   gtl::STLDeleteElements(&streams_);
 }
 
+bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
+  // When there is no more than one stream, we release the tensor reference
+  // at the end of the kernel launch, instead of at the end of the kernel
+  // execution.
+  return streams_.size() > 1;
+}
+
 Status BaseGPUDevice::FillContextMap(const Graph* graph,
                                      DeviceContextMap* device_context_map) {
   VLOG(2) << "FillContextMap";
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index dde16dfafc5..ea785275aa2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -51,7 +51,7 @@ class BaseGPUDevice : public LocalDevice {
   // GPU devices require the Op Compute method to save a reference to
   // any temporary tensors that are allocated until the Op execution
   // completes.
-  bool RequiresRecordingAccessedTensors() const override { return true; }
+  bool RequiresRecordingAccessedTensors() const override;
 
   void ConsumeListOfAccessedTensors(
       DeviceContext* device_context,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index f34ac256d16..d4974ebbd97 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -173,9 +173,9 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
                                  const Tensor* input, Tensor* output,
                                  StatusCallback done) {
   const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
-  gpu::Stream* stream = nullptr;
-  Status s =
-      PrepareCopy(src, send_dev_context, *input, output, &dev_info, &stream);
+  gpu::Stream* send_stream = nullptr;
+  Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info,
+                         &send_stream);
   if (!s.ok()) {
     done(s);
     return;
@@ -187,20 +187,33 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
     DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes);
     void* dst_ptr = GetBase(output);
     DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+    // Since we want to use the memory from recv_stream in the send_stream,
+    // add a dependency to make sure the memory is truly free.
+    // TODO(zhengxq): remove this dependency when we switch to a better way
+    // to make sure the memory is free.
+    auto recv_stream =
+        static_cast<const GPUDeviceContext*>(recv_dev_context)->stream();
+    if (recv_stream == nullptr) {
+      done(errors::Internal("No recv gpu stream is available."));
+      return;
+    }
+    send_stream->ThenWaitFor(recv_stream);
+
     VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
-    stream->ThenMemcpy(&gpu_dst_ptr, gpu_src_ptr, total_bytes);
+    send_stream->ThenMemcpy(&gpu_dst_ptr, gpu_src_ptr, total_bytes);
   }
 
   // Use of input may outlive stack scope, so keep a ref.
   TensorReference input_ref(*input);
-  dev_info->event_mgr->ThenExecute(stream, [done, stream, input_ref]() {
-    input_ref.Unref();
-    if (!stream->ok()) {
-      LOG(FATAL) << "GPU->GPU Memcpy failed";
-    }
-    done(Status::OK());
-  });
-  send_dev_context->MaintainLifetimeOnStream(input, stream);
+  dev_info->event_mgr->ThenExecute(send_stream,
+                                   [done, send_stream, input_ref]() {
+                                     input_ref.Unref();
+                                     if (!send_stream->ok()) {
+                                       LOG(FATAL) << "GPU->GPU Memcpy failed";
+                                     }
+                                     done(Status::OK());
+                                   });
+  send_dev_context->MaintainLifetimeOnStream(input, send_stream);
 }
 
 static CopyTensor::Registration register_gpu_gpu_copy(
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 8f58d3af266..8a9d97f25c6 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -67,6 +67,9 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
           perftools::gputools::DeviceMemory<uint8>>(
           AsDeviceMemory(nullptr, 0));
     }
+    // Hold references to the allocated tensors until the end of the
+    // allocator's lifetime.
+    allocated_tensors_.push_back(temporary_memory);
     return perftools::gputools::port::StatusOr<
         perftools::gputools::DeviceMemory<uint8>>(
         AsDeviceMemory(temporary_memory.flat<uint8>().data(),
@@ -76,6 +79,7 @@
  private:
   int64 memory_limit_;
   OpKernelContext* context_;
+  std::vector<Tensor> allocated_tensors_;
 };
 
 }  // namespace tensorflow
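
Reviewer note on the new `send_stream->ThenWaitFor(recv_stream)` call in gpu_util.cc: StreamExecutor enqueues an on-device dependency between the two streams rather than blocking the host thread. For readers less familiar with StreamExecutor, a minimal sketch of the same pattern in raw CUDA follows; the function name and its arguments are illustrative only, not part of this patch.

// Make send_stream wait (on-device) for all work enqueued so far on
// recv_stream, without stalling the host.
#include <cuda_runtime.h>

void MakeSendStreamWaitForRecvStream(cudaStream_t send_stream,
                                     cudaStream_t recv_stream) {
  cudaEvent_t recv_done;
  // Timing is not needed; disable it to keep the event lightweight.
  cudaEventCreateWithFlags(&recv_done, cudaEventDisableTiming);
  // Mark the point in recv_stream after which its memory is safe to reuse.
  cudaEventRecord(recv_done, recv_stream);
  // Block send_stream (on the device, not the host) until that point.
  cudaStreamWaitEvent(send_stream, recv_done, 0);
  // Safe to destroy immediately: the wait is already enqueued.
  cudaEventDestroy(recv_done);
}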
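
Reviewer note on the `allocated_tensors_` change in conv_ops_gpu.h: storing each scratch tensor in a member vector ties its lifetime to the allocator object, so the backing memory stays valid for as long as the allocator (and therefore the enclosing cudnn call) is alive. A simplified standalone sketch of the pattern, using stand-in types rather than the real `Tensor`/`ScratchAllocator` classes:

#include <cstddef>
#include <cstdint>
#include <vector>

struct Buffer {
  std::vector<uint8_t> bytes;  // stand-in for a Tensor's backing storage
};

class ScopedScratchAllocator {
 public:
  // Returns a pointer that remains valid until this allocator is destroyed.
  // push_back may move earlier Buffer objects, but moving a std::vector
  // does not relocate its heap storage, so earlier pointers stay valid.
  uint8_t* Allocate(size_t num_bytes) {
    allocated_.push_back(Buffer{std::vector<uint8_t>(num_bytes)});
    return allocated_.back().bytes.data();
  }

 private:
  std::vector<Buffer> allocated_;  // holds allocations until destruction
};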