Add missing declarations for explicit instantiations in concat_lib and split_lib, and add new headers concat_lib_gpu.h and split_lib_gpu.h to contain them (and the declaration of the primary templates).
The current behaviour (using externally defined instantiations without having seen a declaration of those external instantiations) is undesirable and effectively deprecated, and is warned about by -Wundefined-func-template.

PiperOrigin-RevId: 237158146
commit 87cd62e4d1
parent 58e052bd77
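For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of what the new headers set up (the names mylib.h, mylib.cc and Sum below are illustrative only and do not appear in this commit): the header declares the primary template and then declares its explicit instantiations with "extern template", while the source file defines the template and emits those instantiations. A translation unit that includes only the header then uses instantiations it has explicitly been told exist elsewhere, which is exactly the situation -Wundefined-func-template stops warning about.

    // mylib.h (hypothetical): declare the primary template and the
    // explicit instantiations that mylib.cc defines.
    #ifndef MYLIB_H_
    #define MYLIB_H_

    template <typename T>
    T Sum(const T* data, int n);  // primary template, defined in mylib.cc

    // Explicit instantiation declarations. Without these, callers of
    // Sum<float> that only see this header rely on an instantiation they
    // have never seen declared, which -Wundefined-func-template flags.
    extern template float Sum<float>(const float* data, int n);
    extern template double Sum<double>(const double* data, int n);

    #endif  // MYLIB_H_

    // mylib.cc (hypothetical): define the template and emit the
    // instantiations declared in the header.
    #include "mylib.h"

    template <typename T>
    T Sum(const T* data, int n) {
      T total = T(0);
      for (int i = 0; i < n; ++i) total += data[i];
      return total;
    }

    template float Sum<float>(const float* data, int n);
    template double Sum<double>(const double* data, int n);

The new concat_lib_gpu.h and split_lib_gpu.h follow this pattern, using the TF_CALL_* type macros to stamp out one extern declaration per instantiated type.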
tensorflow/core/kernels/BUILD
@@ -230,6 +230,7 @@ tf_kernel_library(
     gpu_srcs = [
         "concat_lib_gpu_impl.cu.cc",
         "concat_lib.h",
+        "concat_lib_gpu.h",
         "cuda_device_array.h",
         "cuda_device_array_gpu.h",
     ],
@@ -607,6 +608,7 @@ tf_kernel_library(
     gpu_srcs = [
         "split_lib_gpu.cu.cc",
         "split_lib.h",
+        "split_lib_gpu.h",
     ],
     deps = [
         ":cuda_device_array",
@@ -618,9 +620,7 @@ tf_kernel_library(
 
 cc_library(
     name = "split_lib_hdrs",
-    hdrs = [
-        "split_lib.h",
-    ],
+    hdrs = ["split_lib.h"],
     deps = [
         "//tensorflow/core:framework_lite",
         "//third_party/eigen3",
tensorflow/core/kernels/concat_lib.h
@@ -54,6 +54,24 @@ void ConcatGPU(
         inputs_flat,
     Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
 
+// Explicit instantiations in concat_lib_gpu.cc.
+#define REGISTER(T)                                                            \
+  extern template void ConcatGPU<T>(                                           \
+      OpKernelContext * c,                                                     \
+      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&  \
+          inputs_flat,                                                         \
+      Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+TF_CALL_complex64(REGISTER);
+TF_CALL_complex128(REGISTER);
+TF_CALL_int32(REGISTER);  // Needed for TensorLists.
+TF_CALL_int64(REGISTER);
+TF_CALL_int16(REGISTER);
+TF_CALL_bfloat16(REGISTER);
+TF_CALL_bool(REGISTER);
+TF_CALL_uint8(REGISTER);
+#undef REGISTER
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
tensorflow/core/kernels/concat_lib_gpu.cc
@@ -26,24 +26,10 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 
+#include "tensorflow/core/kernels/concat_lib_gpu.h"
 #include "tensorflow/core/kernels/cuda_device_array.h"
 
 namespace tensorflow {
-
-template <typename T, typename IntType>
-void ConcatGPUSlice(
-    const Eigen::GpuDevice& gpu_device,
-    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
-        inputs_flat,
-    typename TTypes<T, 2>::Matrix* output);
-
-template <typename T, typename IntType>
-void ConcatGPUImpl(const Eigen::GpuDevice& d,
-                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
-                   const CudaDeviceArrayStruct<IntType>& ptr_offsets,
-                   bool same_size, int slice_size,
-                   typename TTypes<T, 2>::Matrix* output);
-
 namespace {
 
 template <typename T, typename IntType>
tensorflow/core/kernels/concat_lib_gpu.h (new file, 82 lines)
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
+
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_GPU
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+
+namespace tensorflow {
+
+template <typename T, typename IntType>
+void ConcatGPUSlice(
+    const Eigen::GpuDevice& gpu_device,
+    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
+        inputs_flat,
+    typename TTypes<T, 2>::Matrix* output);
+
+template <typename T, typename IntType>
+void ConcatGPUImpl(const Eigen::GpuDevice& d,
+                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                   const CudaDeviceArrayStruct<IntType>& ptr_offsets,
+                   bool same_size, int slice_size,
+                   typename TTypes<T, 2>::Matrix* output);
+
+// Explicit instantiations in concat_lib_gpu_impl.cu.cc.
+#define REGISTER(T)                                                            \
+  extern template void ConcatGPUSlice<T, int32>(                               \
+      const Eigen::GpuDevice& gpu_device,                                      \
+      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&  \
+          inputs_flat,                                                         \
+      typename TTypes<T, 2>::Matrix* output);                                  \
+  extern template void ConcatGPUSlice<T, int64>(                               \
+      const Eigen::GpuDevice& gpu_device,                                      \
+      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&  \
+          inputs_flat,                                                         \
+      typename TTypes<T, 2>::Matrix* output);                                  \
+  extern template void ConcatGPUImpl<T, int32>(                                \
+      const Eigen::GpuDevice& d,                                               \
+      const CudaDeviceArrayStruct<const T*>& input_ptrs,                       \
+      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,        \
+      int split_size, typename TTypes<T, 2>::Matrix* output);                  \
+  extern template void ConcatGPUImpl<T, int64>(                                \
+      const Eigen::GpuDevice& d,                                               \
+      const CudaDeviceArrayStruct<const T*>& input_ptrs,                       \
+      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,        \
+      int split_size, typename TTypes<T, 2>::Matrix* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+TF_CALL_complex64(REGISTER);
+TF_CALL_complex128(REGISTER);
+TF_CALL_int32(REGISTER);  // Needed for TensorLists.
+TF_CALL_int64(REGISTER);
+TF_CALL_int16(REGISTER);
+TF_CALL_bfloat16(REGISTER);
+TF_CALL_bool(REGISTER);
+TF_CALL_uint8(REGISTER);
+#undef REGISTER
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_
tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/bfloat16.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/concat_lib_gpu.h"
 #include "tensorflow/core/kernels/cuda_device_array_gpu.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
 
tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+#include "tensorflow/core/kernels/split_lib_gpu.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
@@ -192,54 +193,52 @@ __global__ void SplitVOpKernel_fixed(
 }
 
 template <typename T>
-struct SplitOpGPULaunch {
-  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
+void SplitOpGPULaunch<T>::Run(
+    const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
     int32 split_dim_size, int32 suffix_dim_size,
     const CudaDeviceArrayStruct<T*>& output_ptr_data) {
   CudaLaunchConfig config = GetCudaLaunchConfig(
       prefix_dim_size * split_dim_size * suffix_dim_size, d);
 
   TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(), input,
-                               prefix_dim_size, split_dim_size,
-                               suffix_dim_size, output_ptr_data));
+                               prefix_dim_size, split_dim_size, suffix_dim_size,
+                               output_ptr_data));
 }
-};
 
 template <typename T, typename IntType>
-struct SplitVOpGPULaunch {
-  void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size,
-           const T* input_ptr, int total_rows, int total_cols,
+void SplitVOpGPULaunch<T, IntType>::Run(
+    const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
+    int total_rows, int total_cols,
     const CudaDeviceArrayStruct<IntType>& output_scan,
     const CudaDeviceArrayStruct<T*>& output_ptr_data) {
   if (fixed_size) {
     CudaLaunchConfig config =
        GetCudaLaunchConfig(total_rows * total_cols, gpu_device);
 
     SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
                               gpu_device.stream()>>>(
        input_ptr, total_rows, total_cols, output_ptr_data);
   } else {
     auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
     IntType smem_max = gpu_device.sharedMemPerBlock();
     IntType smem_usage = output_scan.size * sizeof(IntType);
     // performance crossover is less than using maximum available shared
     // memory on most processors possibly due to decreasing occupancy
     // 4096 inputs is a lot, most code will take the smem path
     const int32 kMaxSmemBytesPerformance = 16384;
     if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
       split_v_kernel<T, IntType, true>
           <<<config.block_count, config.thread_per_block, smem_usage,
              gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                     total_cols, output_ptr_data);
     else
       split_v_kernel<T, IntType, false>
           <<<config.block_count, config.thread_per_block, 0,
              gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                     total_cols, output_ptr_data);
-    }
   }
-};
+}
 
 #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;
 
tensorflow/core/kernels/split_lib_gpu.h (new file, 61 lines)
@@ -0,0 +1,61 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
+
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_GPU
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+#include "tensorflow/core/kernels/split_lib.h"
+
+namespace tensorflow {
+
+template <typename T>
+struct SplitOpGPULaunch {
+  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
+           int32 split_dim_size, int32 suffix_dim_size,
+           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+};
+
+template <typename T, typename IntType>
+struct SplitVOpGPULaunch {
+  void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
+           int total_cols, int total_rows,
+           const CudaDeviceArrayStruct<IntType>& output_scan,
+           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+};
+
+// Explicit instantiations in split_lib_gpu.cu.cc.
+#define REGISTER_GPU_KERNEL(T)                        \
+  extern template struct SplitOpGPULaunch<T>;         \
+  extern template struct SplitVOpGPULaunch<T, int32>; \
+  extern template struct SplitVOpGPULaunch<T, int64>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+TF_CALL_complex64(REGISTER_GPU_KERNEL);
+TF_CALL_complex128(REGISTER_GPU_KERNEL);
+TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_
tensorflow/core/kernels/split_op.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_device_array.h"
+#include "tensorflow/core/kernels/split_lib_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -267,13 +268,6 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
 
 #if GOOGLE_CUDA
 
-template <typename T>
-struct SplitOpGPULaunch {
-  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
-           int32 split_dim_size, int32 suffix_dim_size,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
-};
-
 // Partial specialization for GPU
 template <typename T>
 class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
tensorflow/core/kernels/split_v_op.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_device_array.h"
+#include "tensorflow/core/kernels/split_lib_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
 
@@ -329,14 +330,6 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
 
 #if GOOGLE_CUDA
 
-template <typename T, typename IntType>
-struct SplitVOpGPULaunch {
-  void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
-           int total_cols, int total_rows,
-           const CudaDeviceArrayStruct<IntType>& output_scan,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
-};
-
 // Partial specialization for GPU
 template <typename T, typename Tlen>
 class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {