Make gpu_lib for non-cuda deps that we use in public kernels.

Change: 115598732
This commit is contained in:
Vijay Vasudevan 2016-02-25 13:39:24 -08:00 committed by TensorFlower Gardener
parent a5f3979004
commit a82f7e6b55
3 changed files with 36 additions and 8 deletions
tensorflow

View File

@ -750,6 +750,27 @@ cc_library(
],
)
# Libraries for GPU facilities that are useful for writing kernels.
# This target deliberately carries only the event-manager sources
# (gpu_event_mgr.cc/.h) rather than the full GPU runtime, so kernels
# outside this package can depend on it without a CUDA-only dependency.
cc_library(
name = "gpu_lib",
srcs = [
"common_runtime/gpu/gpu_event_mgr.cc",
],
hdrs = [
"common_runtime/gpu/gpu_event_mgr.h",
],
copts = tf_copts(),
# Public so kernel build macros (e.g. tf_kernel_library) can add this
# target to their CUDA deps.
visibility = ["//visibility:public"],
deps = [
":framework",
":framework_internal",
":lib",
":lib_internal",
":protos_all_cc",
# NOTE(review): stream_executor is presumably needed by the event
# manager's stream/event types — confirm against gpu_event_mgr.h.
":stream_executor",
],
)
# -----------------------------------------------------------------------------
# Internal targets
@ -1002,9 +1023,13 @@ tf_cuda_library(
exclude = [
"**/*main.cc",
"**/*test.cc",
"common_runtime/gpu/gpu_event_mgr.cc",
],
),
hdrs = glob(["common_runtime/gpu/*.h"]),
hdrs = glob(
["common_runtime/gpu/*.h"],
exclude = ["common_runtime/gpu/gpu_event_mgr.h"],
),
copts = tf_copts(),
cuda_deps = [
":cuda",
@ -1015,6 +1040,7 @@ tf_cuda_library(
":core_cpu_internal",
":framework",
":framework_internal",
":gpu_lib",
":lib",
":lib_internal",
":protos_all_cc",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@ -229,10 +230,11 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
perftools::gputools::DeviceMemoryBase output_ptrs_base{
output_ptrs_on_gpu.flat<int8>().data(), static_cast<uint64>(num_split)};
TensorReference tensor_ref(output_ptrs_on_host);
stream
->ThenMemcpy(&output_ptrs_base, output_ptrs_on_host.flat<int8>().data(),
output_ptrs_total_bytes)
.ThenDoHostCallback([tensor_ref]() { tensor_ref.Unref(); });
stream->ThenMemcpy(&output_ptrs_base,
output_ptrs_on_host.flat<int8>().data(),
output_ptrs_total_bytes);
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
stream, [tensor_ref]() { tensor_ref.Unref(); });
SplitOpGPULaunch<T>().Run(
context->eigen_device<GPUDevice>(), input.flat<T>().data(), num_split,
prefix_dim_size, split_dim_size, suffix_dim_size,

View File

@ -274,15 +274,15 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None,
srcs = srcs + native.glob([prefix + "*.cc"],
exclude = ["*test*", "*.cu.cc"])
hdrs = hdrs + native.glob([prefix + "*.h"], exclude = ["*test*", "*.cu.h"])
cuda_deps = ["//tensorflow/core:gpu_lib"]
if gpu_srcs:
tf_gpu_kernel_library(
name = name + "_gpu",
srcs = gpu_srcs,
deps = gpu_deps,
**kwargs)
cuda_deps = [":" + name + "_gpu"]
else:
cuda_deps = None
cuda_deps.extend([":" + name + "_gpu"])
tf_cuda_library(
name = name,
srcs = srcs,