Add missing declarations for explicit instantiations in concat_lib and split_lib, and add new headers concat_lib_gpu.h and split_lib_gpu.h to contain them (and the declarations of the primary templates).

The current behaviour (using externally defined instantiations without ever having seen a declaration of those instantiations) is undesirable and effectively deprecated, and Clang warns about it with -Wundefined-func-template.
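To illustrate the pattern this change adopts, here is a minimal sketch with made-up names (WidgetGPU is not a TensorFlow symbol): an explicit instantiation definition in one translation unit, paired with an extern template declaration that every includer sees, so user code neither re-instantiates the template nor trips -Wundefined-func-template.

// widget_lib.h -- hypothetical header.
template <typename T>
void WidgetGPU(const T* in, T* out, int n);  // declaration of the primary template

// Explicit instantiation declaration: the T = float instantiation is
// defined in another translation unit; do not instantiate it here.
extern template void WidgetGPU<float>(const float*, float*, int);

// widget_lib.cc -- the one translation unit that owns the definitions.
template <typename T>
void WidgetGPU(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // stand-in for a GPU launch
}
// Explicit instantiation definition, matching the declaration above.
template void WidgetGPU<float>(const float*, float*, int);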

PiperOrigin-RevId: 237158146
A. Unique TensorFlower, 2019-03-06 18:08:55 -08:00; committed by TensorFlower Gardener
parent 58e052bd77
commit 87cd62e4d1
9 changed files with 210 additions and 76 deletions

tensorflow/core/kernels/BUILD

@@ -230,6 +230,7 @@ tf_kernel_library(
gpu_srcs = [
"concat_lib_gpu_impl.cu.cc",
"concat_lib.h",
"concat_lib_gpu.h",
"cuda_device_array.h",
"cuda_device_array_gpu.h",
],
@@ -607,6 +608,7 @@ tf_kernel_library(
gpu_srcs = [
"split_lib_gpu.cu.cc",
"split_lib.h",
"split_lib_gpu.h",
],
deps = [
":cuda_device_array",
@@ -618,9 +620,7 @@ tf_kernel_library(
cc_library(
name = "split_lib_hdrs",
hdrs = [
"split_lib.h",
],
hdrs = ["split_lib.h"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",

tensorflow/core/kernels/concat_lib.h

@@ -54,6 +54,24 @@ void ConcatGPU(
inputs_flat,
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
// Explicit instantiations in concat_lib_gpu.cc.
#define REGISTER(T) \
extern template void ConcatGPU<T>( \
OpKernelContext * c, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_complex64(REGISTER);
TF_CALL_complex128(REGISTER);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
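For reference, each TF_CALL_*(REGISTER) line above stamps out one such declaration per type; assuming the usual expansion of these macros (TF_CALL_float(m) invoking m(float), and so on), the preprocessor output for float is roughly:

extern template void ConcatGPU<float>(
    OpKernelContext* c,
    const std::vector<std::unique_ptr<typename TTypes<float, 2>::ConstMatrix>>&
        inputs_flat,
    Tensor* output, typename TTypes<float, 2>::Tensor* output_flat);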

tensorflow/core/kernels/concat_lib_gpu.cc

@@ -26,24 +26,10 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
namespace tensorflow {
template <typename T, typename IntType>
void ConcatGPUSlice(
const Eigen::GpuDevice& gpu_device,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs_flat,
typename TTypes<T, 2>::Matrix* output);
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& d,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
const CudaDeviceArrayStruct<IntType>& ptr_offsets,
bool same_size, int slice_size,
typename TTypes<T, 2>::Matrix* output);
namespace {
template <typename T, typename IntType>

tensorflow/core/kernels/concat_lib_gpu.h

@@ -0,0 +1,82 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
#define EIGEN_USE_THREADS
#define EIGEN_USE_GPU
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
namespace tensorflow {
template <typename T, typename IntType>
void ConcatGPUSlice(
const Eigen::GpuDevice& gpu_device,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs_flat,
typename TTypes<T, 2>::Matrix* output);
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& d,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
const CudaDeviceArrayStruct<IntType>& ptr_offsets,
bool same_size, int slice_size,
typename TTypes<T, 2>::Matrix* output);
// Explicit instantiations in concat_lib_gpu_impl.cu.cc.
#define REGISTER(T) \
extern template void ConcatGPUSlice<T, int32>( \
const Eigen::GpuDevice& gpu_device, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUSlice<T, int64>( \
const Eigen::GpuDevice& gpu_device, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
inputs_flat, \
typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUImpl<T, int32>( \
const Eigen::GpuDevice& d, \
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output); \
extern template void ConcatGPUImpl<T, int64>( \
const Eigen::GpuDevice& d, \
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
int split_size, typename TTypes<T, 2>::Matrix* output);
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_complex64(REGISTER);
TF_CALL_complex128(REGISTER);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
TF_CALL_int16(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
#undef REGISTER
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_

tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/concat_lib_gpu.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

tensorflow/core/kernels/split_lib_gpu.cu.cc

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
@@ -192,54 +193,52 @@ __global__ void SplitVOpKernel_fixed(
}
template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
CudaLaunchConfig config = GetCudaLaunchConfig(
prefix_dim_size * split_dim_size * suffix_dim_size, d);
void SplitOpGPULaunch<T>::Run(
const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
CudaLaunchConfig config = GetCudaLaunchConfig(
prefix_dim_size * split_dim_size * suffix_dim_size, d);
TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size,
suffix_dim_size, output_ptr_data));
}
};
TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size, suffix_dim_size,
output_ptr_data));
}
template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size,
const T* input_ptr, int total_rows, int total_cols,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
if (fixed_size) {
CudaLaunchConfig config =
GetCudaLaunchConfig(total_rows * total_cols, gpu_device);
void SplitVOpGPULaunch<T, IntType>::Run(
const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
int total_rows, int total_cols,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
if (fixed_size) {
CudaLaunchConfig config =
GetCudaLaunchConfig(total_rows * total_cols, gpu_device);
SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(
input_ptr, total_rows, total_cols, output_ptr_data);
} else {
auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
IntType smem_max = gpu_device.sharedMemPerBlock();
IntType smem_usage = output_scan.size * sizeof(IntType);
// performance crossover is less than using maximum available shared
// memory on most processors possibly due to decreasing occupancy
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
split_v_kernel<T, IntType, true>
<<<config.block_count, config.thread_per_block, smem_usage,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
else
split_v_kernel<T, IntType, false>
<<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
}
SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(
input_ptr, total_rows, total_cols, output_ptr_data);
} else {
auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
IntType smem_max = gpu_device.sharedMemPerBlock();
IntType smem_usage = output_scan.size * sizeof(IntType);
// performance crossover is less than using maximum available shared
// memory on most processors possibly due to decreasing occupancy
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
split_v_kernel<T, IntType, true>
<<<config.block_count, config.thread_per_block, smem_usage,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
else
split_v_kernel<T, IntType, false>
<<<config.block_count, config.thread_per_block, 0,
gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
total_cols, output_ptr_data);
}
};
}
#define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;
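Note the pairing: template struct SplitOpGPULaunch<T>; (no extern) is an explicit instantiation definition, forcing this .cu.cc to emit code for every member of the struct, which is what satisfies the extern template struct declarations in the new split_lib_gpu.h below. A minimal sketch of the same contract, with hypothetical names:

// launcher.h -- hypothetical.
template <typename T>
struct Launcher {
  void Run(const T* input, int n);  // member defined only in launcher.cu.cc
};
extern template struct Launcher<float>;  // declaration: instantiated elsewhere

// launcher.cu.cc -- hypothetical.
template <typename T>
void Launcher<T>::Run(const T* input, int n) { /* launch a kernel */ }
template struct Launcher<float>;  // definition: emit Launcher<float>::Run here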

tensorflow/core/kernels/split_lib_gpu.h

@@ -0,0 +1,61 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
#define EIGEN_USE_THREADS
#define EIGEN_USE_GPU
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/kernels/split_lib.h"
namespace tensorflow {
template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};
template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
int total_cols, int total_rows,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};
// Explicit instantiations in split_lib_gpu.cu.cc.
#define REGISTER_GPU_KERNEL(T) \
extern template struct SplitOpGPULaunch<T>; \
extern template struct SplitVOpGPULaunch<T, int32>; \
extern template struct SplitVOpGPULaunch<T, int64>;
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
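With these declarations visible, host-compiled op code can call the launchers without ever seeing the CUDA definitions; a hedged sketch of a call site (variable names illustrative, not the exact code in split_op.cc):

// Inside a GPU op's Compute(), after collecting the per-output pointers:
SplitOpGPULaunch<float> launch;
launch.Run(context->eigen_device<Eigen::GpuDevice>(), input_ptr,
           prefix_dim_size, split_dim_size, suffix_dim_size,
           output_ptr_data);
// The linker resolves Run against the explicit instantiation emitted by
// split_lib_gpu.cu.cc via REGISTER_GPU_KERNEL.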

tensorflow/core/kernels/split_op.cc

@@ -30,6 +30,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -267,13 +268,6 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
#if GOOGLE_CUDA
template <typename T>
struct SplitOpGPULaunch {
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
int32 split_dim_size, int32 suffix_dim_size,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};
// Partial specialization for GPU
template <typename T>
class SplitOpGPU : public SplitOpBase<GPUDevice, T> {

tensorflow/core/kernels/split_v_op.cc

@@ -36,6 +36,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -329,14 +330,6 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
#if GOOGLE_CUDA
template <typename T, typename IntType>
struct SplitVOpGPULaunch {
void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
int total_cols, int total_rows,
const CudaDeviceArrayStruct<IntType>& output_scan,
const CudaDeviceArrayStruct<T*>& output_ptr_data);
};
// Partial specialization for GPU
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {