Changing from Cuda* to Gpu* in the contents of the tensorflow/core/kernels/cuda_device_array* files

Deven Desai 2019-03-13 17:18:04 +00:00
parent 88b8eebcf2
commit d72c60531a
12 changed files with 83 additions and 83 deletions

View File

@@ -36,8 +36,8 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename T, bool useSharedMem>
 __global__ void BucketizeCustomKernel(
     const int32 size_in, const T* in, const int32 size_boundaries,
-    CudaDeviceArrayStruct<float> boundaries_array, int32* out) {
-  const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array);
+    GpuDeviceArrayStruct<float> boundaries_array, int32* out) {
+  const float* boundaries = GetGpuDeviceArrayOnDevice(&boundaries_array);
   extern __shared__ __align__(sizeof(float)) unsigned char shared_mem[];
   float* shared_mem_boundaries = reinterpret_cast<float*>(shared_mem);
@@ -85,7 +85,7 @@ struct BucketizeFunctor<GPUDevice, T> {
                         typename TTypes<int32, 1>::Tensor& output) {
     const GPUDevice& d = context->eigen_device<GPUDevice>();
-    CudaDeviceArrayOnHost<float> boundaries_array(context,
+    GpuDeviceArrayOnHost<float> boundaries_array(context,
                                                   boundaries_vector.size());
     TF_RETURN_IF_ERROR(boundaries_array.Init());
     for (int i = 0; i < boundaries_vector.size(); ++i) {

View File

@@ -38,14 +38,14 @@ void ConcatGPUCall(
     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
         inputs_flat,
     typename TTypes<T, 2>::Tensor* output_flat) {
-  CudaDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
+  GpuDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
   OP_REQUIRES_OK(c, input_ptrs.Init());
   for (int i = 0; i < inputs_flat.size(); ++i) {
     input_ptrs.Set(i, inputs_flat[i]->data());
   }
   OP_REQUIRES_OK(c, input_ptrs.Finalize());
-  CudaDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
+  GpuDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
   OP_REQUIRES_OK(c, output_scan.Init());
   IntType scan = 0;
   output_scan.Set(0, scan);
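
As context for this hunk: output_scan holds a running prefix sum of each input's column count, so the concat kernels can map a global output column back to the input that owns it. A minimal sketch of how that scan is presumably filled in right after the lines above (variable names taken from the hunk; treat it as illustrative, not the verbatim file contents):

  // Sketch only: build the column-offset scan that the device-side
  // upper_bound search consumes.
  for (int i = 0; i < inputs_flat.size(); ++i) {
    scan += inputs_flat[i]->dimension(1);  // columns contributed by input i
    output_scan.Set(i + 1, scan);
  }
  OP_REQUIRES_OK(c, output_scan.Finalize());  // copy offsets to the device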

View File

@@ -38,8 +38,8 @@ void ConcatGPUSlice(
 template <typename T, typename IntType>
 void ConcatGPUImpl(const Eigen::GpuDevice& d,
-                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
-                   const CudaDeviceArrayStruct<IntType>& ptr_offsets,
+                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
+                   const GpuDeviceArrayStruct<IntType>& ptr_offsets,
                    bool same_size, int slice_size,
                    typename TTypes<T, 2>::Matrix* output);
@@ -57,13 +57,13 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
       typename TTypes<T, 2>::Matrix* output); \
   extern template void ConcatGPUImpl<T, int32>( \
       const Eigen::GpuDevice& d, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, \
-      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, \
+      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output); \
   extern template void ConcatGPUImpl<T, int64>( \
       const Eigen::GpuDevice& d, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, \
-      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, \
+      const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER);

View File

@@ -35,9 +35,9 @@ namespace {
 template <typename T, typename IntType>
 __global__ void concat_fixed_kernel(
-    CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
+    GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
     int total_rows, int total_cols, T* output) {
-  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
+  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
   for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
@@ -59,11 +59,11 @@ __global__ void concat_fixed_kernel(
 // cannot be in anonymous namespace due to extern shared memory
 template <typename T, typename IntType, bool useSmem>
 __global__ void concat_variable_kernel(
-    CudaDeviceArrayStruct<const T*> input_ptr_data,
-    CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
+    GpuDeviceArrayStruct<const T*> input_ptr_data,
+    GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
     IntType total_cols, T* output) {
-  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
-  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
+  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
+  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
   // do upper_bound on col to find which pointer we should be using
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -136,8 +136,8 @@ void ConcatGPUSlice(
 template <typename T, typename IntType>
 void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
-                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
-                   const CudaDeviceArrayStruct<IntType>& output_scan,
+                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
+                   const GpuDeviceArrayStruct<IntType>& output_scan,
                    bool fixed_size, int split_size,
                    typename TTypes<T, 2>::Matrix* output) {
   auto config = GetCuda2DLaunchConfig(output->dimension(1),
@@ -188,15 +188,15 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
 #define REGISTER_GPU32(T) \
   template void ConcatGPUImpl<T, int32>( \
       const Eigen::GpuDevice& d, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, \
-      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, \
+      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);
 #define REGISTER_GPU64(T) \
   template void ConcatGPUImpl<T, int64>( \
      const Eigen::GpuDevice& d, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, \
-      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, \
+      const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);

View File

@@ -138,15 +138,15 @@ class DynamicStitchOpImplBase : public OpKernel {
 template <typename T>
 void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                           const int32 slice_size, const int32 first_dim_size,
-                          const CudaDeviceArrayStruct<int>& input_indices,
-                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          const GpuDeviceArrayStruct<int>& input_indices,
+                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                           T* output);
 #define REGISTER_GPU(T) \
   extern template void DynamicStitchGPUImpl( \
       const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
       const int32 first_dim_size, \
-      const CudaDeviceArrayStruct<int32>& input_indices, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+      const GpuDeviceArrayStruct<int32>& input_indices, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
@@ -185,8 +185,8 @@ class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
     // implicitly using atomics to make sure the last index is the final
     // write.
     const int slice_size = merged->flat_outer_dims<T>().dimension(1);
-    CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
-    CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+    GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+    GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
     OP_REQUIRES_OK(c, indices_flat.Init());
     OP_REQUIRES_OK(c, data_flat.Init());
     // initialize the indices_flat (-1 represents missing indices)

View File

@@ -31,11 +31,11 @@ namespace {
 template <typename T>
 __global__ void DynamicStitchKernel(const int32 slice_size,
                                     const int32 output_size,
-                                    CudaDeviceArrayStruct<int32> input_indices,
-                                    CudaDeviceArrayStruct<const T*> input_ptrs,
+                                    GpuDeviceArrayStruct<int32> input_indices,
+                                    GpuDeviceArrayStruct<const T*> input_ptrs,
                                     T* output) {
-  int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
-  const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
+  int32* data_indices = GetGpuDeviceArrayOnDevice(&input_indices);
+  const T** data_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs);
   CUDA_1D_KERNEL_LOOP(output_index, output_size) {
     const int32 slice_id = output_index / slice_size;
     const int32 slice_offset = output_index % slice_size;
@@ -51,8 +51,8 @@ __global__ void DynamicStitchKernel(const int32 slice_size,
 template <typename T>
 void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                           const int32 slice_size, const int32 first_dim_size,
-                          const CudaDeviceArrayStruct<int>& input_indices,
-                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          const GpuDeviceArrayStruct<int>& input_indices,
+                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                           T* output) {
   const int32 output_size = first_dim_size * slice_size;
   auto config = GetCudaLaunchConfig(output_size, gpu_device);
@@ -67,8 +67,8 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
   template void DynamicStitchGPUImpl( \
       const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
       const int32 first_dim_size, \
-      const CudaDeviceArrayStruct<int32>& input_indices, \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+      const GpuDeviceArrayStruct<int32>& input_indices, \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
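
The loop body elided from the first hunk of this file follows the convention noted earlier (-1 marks a slice that received no data): each output element is mapped to a slice and copied only if an index was provided for it. A hedged sketch of that inner loop, reusing the names shown in the hunk; the exact body in the tree may differ in details such as read intrinsics:

  // Sketch only: continues the CUDA_1D_KERNEL_LOOP shown above.
  const int32 input_index = data_indices[slice_id];
  if (input_index != -1) {  // -1 means nothing was stitched into this slice
    output[output_index] = data_ptrs[input_index][slice_offset];
  }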

View File

@@ -15,7 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -24,11 +24,11 @@ limitations under the License.
 namespace tensorflow {
 // Create an array of value on the host, to be sent to kernel using
-// CudaDeviceArrayStruct.
+// GpuDeviceArrayStruct.
 //
 // Usage:
 //   int size = ...;
-//   CudaDeviceArrayOnHost ptrs(context, size);
+//   GpuDeviceArrayOnHost ptrs(context, size);
 //   OP_REQUIRES_OK(ptrs.Init());
 //   for (int i = 0; i < size; ++i) {
 //     ptrs.Set(i, ...);
@@ -38,9 +38,9 @@ namespace tensorflow {
 //
 // ValueType must be memcopyable.
 template <typename ValueType, int MaxInlineValues = 8>
-class CudaDeviceArrayOnHost {
+class GpuDeviceArrayOnHost {
  public:
-  CudaDeviceArrayOnHost(OpKernelContext* context, int32 size)
+  GpuDeviceArrayOnHost(OpKernelContext* context, int32 size)
       : context_(context),
         total_bytes_(static_cast<int64>(size) * sizeof(ValueType)) {
     data_.size = size;
@@ -93,7 +93,7 @@ class CudaDeviceArrayOnHost {
     return Status::OK();
   }
-  const CudaDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
+  const GpuDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
     // Ensure Finalize is called.
     DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized());
     return data_;
@@ -105,16 +105,16 @@ class CudaDeviceArrayOnHost {
   OpKernelContext* const context_;
   const int64 total_bytes_;  // total size of all pointers.
   ValueType* values_ = nullptr;
-  CudaDeviceArrayStruct<ValueType, MaxInlineValues> data_;
+  GpuDeviceArrayStruct<ValueType, MaxInlineValues> data_;
   Tensor out_of_line_values_on_host_;
   Tensor out_of_line_values_on_gpu_;
-  TF_DISALLOW_COPY_AND_ASSIGN(CudaDeviceArrayOnHost);
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuDeviceArrayOnHost);
 };
 }  // namespace tensorflow
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
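
After the rename, the host-side flow described by the usage comment in this header is unchanged: Init, Set each element, Finalize, then pass data() by value to a kernel launch. A minimal sketch of that flow; the kernel name MyKernel, the float element type, and the host_values/blocks/threads/stream variables are illustrative only:

  // Sketch only: stage a variable-length array for one kernel launch.
  GpuDeviceArrayOnHost<float> vals(context, size);
  OP_REQUIRES_OK(context, vals.Init());
  for (int i = 0; i < size; ++i) {
    vals.Set(i, host_values[i]);             // fill values on the host
  }
  OP_REQUIRES_OK(context, vals.Finalize());   // copies out-of-line storage to the GPU
  // vals.data() is a GpuDeviceArrayStruct<float> that is passed by value
  // to a __global__ kernel, e.g. the hypothetical MyKernel.
  MyKernel<<<blocks, threads, 0, stream>>>(vals.data(), output_ptr);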

View File

@@ -18,15 +18,15 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace tensorflow {
-static constexpr int kMaxInlineCudaPointers = 8;
-// To decode on the device side, use GetCudaDeviceArrayOnDevice.
-// To encode on the host side, use CudaDeviceArrayOnHost.
+static constexpr int kMaxInlineGpuPointers = 8;
+// To decode on the device side, use GetGpuDeviceArrayOnDevice.
+// To encode on the host side, use GpuDeviceArrayOnHost.
 template <typename ValueType, int MaxInlineValues = 8>
-struct CudaDeviceArrayStruct {
+struct GpuDeviceArrayStruct {
   int32 size;
   // used if size <= MaxInlineValues;
   ValueType inline_values[MaxInlineValues];
@@ -34,8 +34,8 @@ struct CudaDeviceArrayStruct {
 };
 template <typename ValueType, int MaxInlineValues = 8>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
-    CudaDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice(
+    GpuDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
   if (data->size <= MaxInlineValues) {
     return data->inline_values;
   } else {
@@ -45,6 +45,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
 }  // namespace tensorflow
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
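
On the device side, the renamed accessor resolves either the inline storage (size <= MaxInlineValues) or the out-of-line device copy made by GpuDeviceArrayOnHost::Finalize(). A hedged sketch of a kernel consuming the struct after the rename; the kernel name ExampleKernel and the element type are illustrative only:

  // Sketch only: decode a GpuDeviceArrayStruct inside a kernel.
  __global__ void ExampleKernel(GpuDeviceArrayStruct<const float*> ptr_data,
                                float* out, int n) {
    // Resolves to inline_values or to the out-of-line device buffer.
    const float** ptrs = GetGpuDeviceArrayOnDevice(&ptr_data);
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = ptrs[i][0];
  }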

View File

@@ -77,9 +77,9 @@ namespace {
 template <typename T>
 __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
                               int32 split_dim_size, int32 suffix_dim_size,
-                              CudaDeviceArrayStruct<T*> output_ptr_data) {
+                              GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
   eigen_assert(blockDim.y == 1);
   eigen_assert(blockDim.z == 1);
@@ -114,11 +114,11 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
 // is reversed
 template <typename T, typename IntType, bool useSmem>
 __global__ void split_v_kernel(const T* input_ptr,
-                               CudaDeviceArrayStruct<IntType> output_scan,
+                               GpuDeviceArrayStruct<IntType> output_scan,
                                IntType total_rows, IntType total_cols,
-                               CudaDeviceArrayStruct<T*> output_ptr_data) {
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
-  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
+                               GpuDeviceArrayStruct<T*> output_ptr_data) {
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
+  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
   // do upper_bound on col to find which pointer we should be using
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -170,11 +170,11 @@ __global__ void split_v_kernel(const T* input_ptr,
 // different from the original split implementation due to 2D vs 3D
 // dimensions. This version is likely faster due to less integer math.
 template <typename T>
-__global__ void SplitVOpKernel_fixed(
-    const T* input, int32 prefix_dim_size, int32 suffix_dim_size,
-    CudaDeviceArrayStruct<T*> output_ptr_data) {
+__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
+                                     int32 suffix_dim_size,
+                                     GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
   eigen_assert(blockDim.y == 1);
   eigen_assert(blockDim.z == 1);
@@ -195,10 +195,10 @@ __global__ void SplitVOpKernel_fixed(
 }
 template <typename T>
-void SplitOpGPULaunch<T>::Run(
-    const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
-    int32 split_dim_size, int32 suffix_dim_size,
-    const CudaDeviceArrayStruct<T*>& output_ptr_data) {
+void SplitOpGPULaunch<T>::Run(const Eigen::GpuDevice& d, const T* input,
+                              int32 prefix_dim_size, int32 split_dim_size,
+                              int32 suffix_dim_size,
+                              const GpuDeviceArrayStruct<T*>& output_ptr_data) {
   CudaLaunchConfig config = GetCudaLaunchConfig(
       prefix_dim_size * split_dim_size * suffix_dim_size, d);
@@ -212,8 +212,8 @@ template <typename T, typename IntType>
 void SplitVOpGPULaunch<T, IntType>::Run(
     const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
     int total_rows, int total_cols,
-    const CudaDeviceArrayStruct<IntType>& output_scan,
-    const CudaDeviceArrayStruct<T*>& output_ptr_data) {
+    const GpuDeviceArrayStruct<IntType>& output_scan,
+    const GpuDeviceArrayStruct<T*>& output_ptr_data) {
   if (fixed_size) {
     CudaLaunchConfig config =
         GetCudaLaunchConfig(total_rows * total_cols, gpu_device);

View File

@@ -33,15 +33,15 @@ template <typename T>
 struct SplitOpGPULaunch {
   void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
            int32 split_dim_size, int32 suffix_dim_size,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+           const GpuDeviceArrayStruct<T*>& output_ptr_data);
 };
 template <typename T, typename IntType>
 struct SplitVOpGPULaunch {
   void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
            int total_cols, int total_rows,
-           const CudaDeviceArrayStruct<IntType>& output_scan,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+           const GpuDeviceArrayStruct<IntType>& output_scan,
+           const GpuDeviceArrayStruct<T*>& output_ptr_data);
 };
 // Explicit instantiations in split_lib_gpu.cu.cc.

View File

@@ -302,7 +302,7 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     TensorShape output_shape(input_shape);
     output_shape.set_dim(split_dim, split_dim_output_size);
-    CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
+    GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
     OP_REQUIRES_OK(context, ptrs.Init());
     for (int i = 0; i < num_split; ++i) {

View File

@@ -366,10 +366,10 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
     // reshape to 2D
     if (num_split > 16) {
-      CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
+      GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
       OP_REQUIRES_OK(context, ptrs.Init());
-      CudaDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
+      GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
       OP_REQUIRES_OK(context, offsets.Init());
       Tlen offset = 0;