diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a6195945054..63af48679a7 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -230,6 +230,7 @@ tf_kernel_library( gpu_srcs = [ "concat_lib_gpu_impl.cu.cc", "concat_lib.h", + "concat_lib_gpu.h", "cuda_device_array.h", "cuda_device_array_gpu.h", ], @@ -607,6 +608,7 @@ tf_kernel_library( gpu_srcs = [ "split_lib_gpu.cu.cc", "split_lib.h", + "split_lib_gpu.h", ], deps = [ ":cuda_device_array", @@ -618,9 +620,7 @@ tf_kernel_library( cc_library( name = "split_lib_hdrs", - hdrs = [ - "split_lib.h", - ], + hdrs = ["split_lib.h"], deps = [ "//tensorflow/core:framework_lite", "//third_party/eigen3", diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h index 8b53ecf1216..dab4932b0e2 100644 --- a/tensorflow/core/kernels/concat_lib.h +++ b/tensorflow/core/kernels/concat_lib.h @@ -54,6 +54,24 @@ void ConcatGPU( inputs_flat, Tensor* output, typename TTypes::Tensor* output_flat); +// Explicit instantiations in concat_lib_gpu.cc. +#define REGISTER(T) \ + extern template void ConcatGPU( \ + OpKernelContext * c, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + Tensor* output, typename TTypes::Tensor* output_flat); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_complex64(REGISTER); +TF_CALL_complex128(REGISTER); +TF_CALL_int32(REGISTER); // Needed for TensorLists. +TF_CALL_int64(REGISTER); +TF_CALL_int16(REGISTER); +TF_CALL_bfloat16(REGISTER); +TF_CALL_bool(REGISTER); +TF_CALL_uint8(REGISTER); +#undef REGISTER #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/concat_lib_gpu.cc b/tensorflow/core/kernels/concat_lib_gpu.cc index 853d7c3133d..f4f16a6bb17 100644 --- a/tensorflow/core/kernels/concat_lib_gpu.cc +++ b/tensorflow/core/kernels/concat_lib_gpu.cc @@ -26,24 +26,10 @@ limitations under the License. #if GOOGLE_CUDA +#include "tensorflow/core/kernels/concat_lib_gpu.h" #include "tensorflow/core/kernels/cuda_device_array.h" namespace tensorflow { - -template -void ConcatGPUSlice( - const Eigen::GpuDevice& gpu_device, - const std::vector::ConstMatrix>>& - inputs_flat, - typename TTypes::Matrix* output); - -template -void ConcatGPUImpl(const Eigen::GpuDevice& d, - const CudaDeviceArrayStruct& input_ptrs, - const CudaDeviceArrayStruct& ptr_offsets, - bool same_size, int slice_size, - typename TTypes::Matrix* output); - namespace { template diff --git a/tensorflow/core/kernels/concat_lib_gpu.h b/tensorflow/core/kernels/concat_lib_gpu.h new file mode 100644 index 00000000000..f8898e6537b --- /dev/null +++ b/tensorflow/core/kernels/concat_lib_gpu.h @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_GPU + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/cuda_device_array_gpu.h" + +namespace tensorflow { + +template +void ConcatGPUSlice( + const Eigen::GpuDevice& gpu_device, + const std::vector::ConstMatrix>>& + inputs_flat, + typename TTypes::Matrix* output); + +template +void ConcatGPUImpl(const Eigen::GpuDevice& d, + const CudaDeviceArrayStruct& input_ptrs, + const CudaDeviceArrayStruct& ptr_offsets, + bool same_size, int slice_size, + typename TTypes::Matrix* output); + +// Explicit instantiations in concat_lib_gpu_impl.cu.cc. +#define REGISTER(T) \ + extern template void ConcatGPUSlice( \ + const Eigen::GpuDevice& gpu_device, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + typename TTypes::Matrix* output); \ + extern template void ConcatGPUSlice( \ + const Eigen::GpuDevice& gpu_device, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + typename TTypes::Matrix* output); \ + extern template void ConcatGPUImpl( \ + const Eigen::GpuDevice& d, \ + const CudaDeviceArrayStruct& input_ptrs, \ + const CudaDeviceArrayStruct& ptr_offsets, bool fixed_size, \ + int split_size, typename TTypes::Matrix* output); \ + extern template void ConcatGPUImpl( \ + const Eigen::GpuDevice& d, \ + const CudaDeviceArrayStruct& input_ptrs, \ + const CudaDeviceArrayStruct& ptr_offsets, bool fixed_size, \ + int split_size, typename TTypes::Matrix* output); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_complex64(REGISTER); +TF_CALL_complex128(REGISTER); +TF_CALL_int32(REGISTER); // Needed for TensorLists. +TF_CALL_int64(REGISTER); +TF_CALL_int16(REGISTER); +TF_CALL_bfloat16(REGISTER); +TF_CALL_bool(REGISTER); +TF_CALL_uint8(REGISTER); +#undef REGISTER + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc index ae828b5bf48..8727fb736ab 100644 --- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc +++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/concat_lib_gpu.h" #include "tensorflow/core/kernels/cuda_device_array_gpu.h" #include "tensorflow/core/util/cuda_kernel_helper.h" diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc index 3d42f2dc70b..73eeb417f33 100644 --- a/tensorflow/core/kernels/split_lib_gpu.cu.cc +++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cuda_device_array_gpu.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" #include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { @@ -192,54 +193,52 @@ __global__ void SplitVOpKernel_fixed( } template -struct SplitOpGPULaunch { - void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, - int32 split_dim_size, int32 suffix_dim_size, - const CudaDeviceArrayStruct& output_ptr_data) { - CudaLaunchConfig config = GetCudaLaunchConfig( - prefix_dim_size * split_dim_size * suffix_dim_size, d); +void SplitOpGPULaunch::Run( + const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, + int32 split_dim_size, int32 suffix_dim_size, + const CudaDeviceArrayStruct& output_ptr_data) { + CudaLaunchConfig config = GetCudaLaunchConfig( + prefix_dim_size * split_dim_size * suffix_dim_size, d); - TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel, config.block_count, - config.thread_per_block, 0, d.stream(), input, - prefix_dim_size, split_dim_size, - suffix_dim_size, output_ptr_data)); - } -}; + TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel, config.block_count, + config.thread_per_block, 0, d.stream(), input, + prefix_dim_size, split_dim_size, suffix_dim_size, + output_ptr_data)); +} template -struct SplitVOpGPULaunch { - void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size, - const T* input_ptr, int total_rows, int total_cols, - const CudaDeviceArrayStruct& output_scan, - const CudaDeviceArrayStruct& output_ptr_data) { - if (fixed_size) { - CudaLaunchConfig config = - GetCudaLaunchConfig(total_rows * total_cols, gpu_device); +void SplitVOpGPULaunch::Run( + const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr, + int total_rows, int total_cols, + const CudaDeviceArrayStruct& output_scan, + const CudaDeviceArrayStruct& output_ptr_data) { + if (fixed_size) { + CudaLaunchConfig config = + GetCudaLaunchConfig(total_rows * total_cols, gpu_device); - SplitVOpKernel_fixed<<>>( - input_ptr, total_rows, total_cols, output_ptr_data); - } else { - auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device); - IntType smem_max = gpu_device.sharedMemPerBlock(); - IntType smem_usage = output_scan.size * sizeof(IntType); - // performance crossover is less than using maximum available shared - // memory on most processors possibly due to decreasing occupancy - // 4096 inputs is a lot, most code will take the smem path - const int32 kMaxSmemBytesPerformance = 16384; - if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) - split_v_kernel - <<>>(input_ptr, output_scan, total_rows, - total_cols, output_ptr_data); - else - split_v_kernel - <<>>(input_ptr, output_scan, total_rows, - total_cols, output_ptr_data); - } + SplitVOpKernel_fixed<<>>( + input_ptr, total_rows, total_cols, output_ptr_data); + } else { + auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device); + IntType smem_max = gpu_device.sharedMemPerBlock(); + IntType smem_usage = output_scan.size * sizeof(IntType); + // performance crossover is less than using maximum available shared + // memory on most processors possibly due to decreasing occupancy + // 4096 inputs is a lot, most code will take the smem path + const int32 kMaxSmemBytesPerformance = 16384; + if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) + split_v_kernel + <<>>(input_ptr, output_scan, total_rows, + total_cols, output_ptr_data); + else + split_v_kernel + <<>>(input_ptr, output_scan, total_rows, + total_cols, output_ptr_data); } -}; +} #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch; diff --git a/tensorflow/core/kernels/split_lib_gpu.h b/tensorflow/core/kernels/split_lib_gpu.h new file mode 100644 index 00000000000..85d2da912ee --- /dev/null +++ b/tensorflow/core/kernels/split_lib_gpu.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_GPU + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/cuda_device_array_gpu.h" +#include "tensorflow/core/kernels/split_lib.h" + +namespace tensorflow { + +template +struct SplitOpGPULaunch { + void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, + int32 split_dim_size, int32 suffix_dim_size, + const CudaDeviceArrayStruct& output_ptr_data); +}; + +template +struct SplitVOpGPULaunch { + void Run(const Eigen::GpuDevice& d, bool fixed, const T* input, + int total_cols, int total_rows, + const CudaDeviceArrayStruct& output_scan, + const CudaDeviceArrayStruct& output_ptr_data); +}; + +// Explicit instantiations in split_lib_gpu.cu.cc. +#define REGISTER_GPU_KERNEL(T) \ + extern template struct SplitOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); +TF_CALL_complex64(REGISTER_GPU_KERNEL); +TF_CALL_complex128(REGISTER_GPU_KERNEL); +TF_CALL_bfloat16(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index ed3429ff5cb..d69ce3f5853 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -30,6 +30,7 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -267,13 +268,6 @@ class SplitOpCPU : public SplitOpBase { #if GOOGLE_CUDA -template -struct SplitOpGPULaunch { - void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, - int32 split_dim_size, int32 suffix_dim_size, - const CudaDeviceArrayStruct& output_ptr_data); -}; - // Partial specialization for GPU template class SplitOpGPU : public SplitOpBase { diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc index 0324ce9babc..96b240b1ac7 100644 --- a/tensorflow/core/kernels/split_v_op.cc +++ b/tensorflow/core/kernels/split_v_op.cc @@ -36,6 +36,7 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -329,14 +330,6 @@ class SplitVOpCPU : public SplitVOpBase { #if GOOGLE_CUDA -template -struct SplitVOpGPULaunch { - void Run(const Eigen::GpuDevice& d, bool fixed, const T* input, - int total_cols, int total_rows, - const CudaDeviceArrayStruct& output_scan, - const CudaDeviceArrayStruct& output_ptr_data); -}; - // Partial specialization for GPU template class SplitVOpGPU : public SplitVOpBase {