Make gpu_lib for non-cuda deps that we use in public kernels.
Change: 115598732
This commit is contained in:
parent
a5f3979004
commit
a82f7e6b55
tensorflow
@@ -750,6 +750,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Libraries for GPU facilities that are useful for writing kernels.
|
||||
cc_library(
|
||||
name = "gpu_lib",
|
||||
srcs = [
|
||||
"common_runtime/gpu/gpu_event_mgr.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"common_runtime/gpu/gpu_event_mgr.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":protos_all_cc",
|
||||
":stream_executor",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Internal targets
|
||||
|
||||
@@ -1002,9 +1023,13 @@ tf_cuda_library(
|
||||
exclude = [
|
||||
"**/*main.cc",
|
||||
"**/*test.cc",
|
||||
"common_runtime/gpu/gpu_event_mgr.cc",
|
||||
],
|
||||
),
|
||||
hdrs = glob(["common_runtime/gpu/*.h"]),
|
||||
hdrs = glob(
|
||||
["common_runtime/gpu/*.h"],
|
||||
exclude = ["common_runtime/gpu/gpu_event_mgr.h"],
|
||||
),
|
||||
copts = tf_copts(),
|
||||
cuda_deps = [
|
||||
":cuda",
|
||||
@@ -1015,6 +1040,7 @@ tf_cuda_library(
|
||||
":core_cpu_internal",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":gpu_lib",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":protos_all_cc",
|
||||
|
@@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
@@ -229,10 +230,11 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
|
||||
perftools::gputools::DeviceMemoryBase output_ptrs_base{
|
||||
output_ptrs_on_gpu.flat<int8>().data(), static_cast<uint64>(num_split)};
|
||||
TensorReference tensor_ref(output_ptrs_on_host);
|
||||
stream
|
||||
->ThenMemcpy(&output_ptrs_base, output_ptrs_on_host.flat<int8>().data(),
|
||||
output_ptrs_total_bytes)
|
||||
.ThenDoHostCallback([tensor_ref]() { tensor_ref.Unref(); });
|
||||
stream->ThenMemcpy(&output_ptrs_base,
|
||||
output_ptrs_on_host.flat<int8>().data(),
|
||||
output_ptrs_total_bytes);
|
||||
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||
stream, [tensor_ref]() { tensor_ref.Unref(); });
|
||||
SplitOpGPULaunch<T>().Run(
|
||||
context->eigen_device<GPUDevice>(), input.flat<T>().data(), num_split,
|
||||
prefix_dim_size, split_dim_size, suffix_dim_size,
|
||||
|
@@ -274,15 +274,15 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None,
|
||||
srcs = srcs + native.glob([prefix + "*.cc"],
|
||||
exclude = ["*test*", "*.cu.cc"])
|
||||
hdrs = hdrs + native.glob([prefix + "*.h"], exclude = ["*test*", "*.cu.h"])
|
||||
|
||||
cuda_deps = ["//tensorflow/core:gpu_lib"]
|
||||
if gpu_srcs:
|
||||
tf_gpu_kernel_library(
|
||||
name = name + "_gpu",
|
||||
srcs = gpu_srcs,
|
||||
deps = gpu_deps,
|
||||
**kwargs)
|
||||
cuda_deps = [":" + name + "_gpu"]
|
||||
else:
|
||||
cuda_deps = None
|
||||
cuda_deps.extend([":" + name + "_gpu"])
|
||||
tf_cuda_library(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
|
Loading…
Reference in New Issue
Block a user