From d72c60531a49c6e3f301c07cf6b7af7c12551a92 Mon Sep 17 00:00:00 2001
From: Deven Desai
Date: Wed, 13 Mar 2019 17:18:04 +0000
Subject: [PATCH] Changing from Cuda* to Gpu*, the contents of
 tensorflow/core/kernels/cuda_device_array* files

---
 .../core/kernels/bucketize_op_gpu.cu.cc       |  8 ++---
 tensorflow/core/kernels/concat_lib_gpu.cc     |  4 +--
 tensorflow/core/kernels/concat_lib_gpu.h      | 12 +++----
 .../core/kernels/concat_lib_gpu_impl.cu.cc    | 36 +++++++++----------
 tensorflow/core/kernels/dynamic_stitch_op.cc  | 12 +++----
 .../core/kernels/dynamic_stitch_op_gpu.cu.cc  | 16 ++++-----
 tensorflow/core/kernels/gpu_device_array.h    | 18 +++++-----
 .../core/kernels/gpu_device_array_gpu.h       | 16 ++++-----
 tensorflow/core/kernels/split_lib_gpu.cu.cc   | 32 ++++++++---------
 tensorflow/core/kernels/split_lib_gpu.h       |  6 ++--
 tensorflow/core/kernels/split_op.cc           |  2 +-
 tensorflow/core/kernels/split_v_op.cc         |  4 +--
 12 files changed, 83 insertions(+), 83 deletions(-)

diff --git a/tensorflow/core/kernels/bucketize_op_gpu.cu.cc b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
index 31c73fadf50..516468c768f 100644
--- a/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc
@@ -36,8 +36,8 @@ typedef Eigen::GpuDevice GPUDevice;

 template <typename T, bool useSharedMem>
 __global__ void BucketizeCustomKernel(
     const int32 size_in, const T* in, const int32 size_boundaries,
-    CudaDeviceArrayStruct<float> boundaries_array, int32* out) {
-  const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array);
+    GpuDeviceArrayStruct<float> boundaries_array, int32* out) {
+  const float* boundaries = GetGpuDeviceArrayOnDevice(&boundaries_array);
   extern __shared__ __align__(sizeof(float)) unsigned char shared_mem[];
   float* shared_mem_boundaries = reinterpret_cast<float*>(shared_mem);
@@ -85,8 +85,8 @@ struct BucketizeFunctor<GPUDevice, T> {
                         typename TTypes<int32, 1>::Tensor& output) {
     const GPUDevice& d = context->eigen_device<GPUDevice>();

-    CudaDeviceArrayOnHost<float> boundaries_array(context,
-                                                  boundaries_vector.size());
+    GpuDeviceArrayOnHost<float> boundaries_array(context,
+                                                 boundaries_vector.size());
     TF_RETURN_IF_ERROR(boundaries_array.Init());
     for (int i = 0; i < boundaries_vector.size(); ++i) {
       boundaries_array.Set(i, boundaries_vector[i]);
diff --git a/tensorflow/core/kernels/concat_lib_gpu.cc b/tensorflow/core/kernels/concat_lib_gpu.cc
index 3108558a220..a75d464c31d 100644
--- a/tensorflow/core/kernels/concat_lib_gpu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu.cc
@@ -38,14 +38,14 @@ void ConcatGPUCall(
     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
         inputs_flat,
     typename TTypes<T, 2>::Tensor* output_flat) {
-  CudaDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
+  GpuDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
   OP_REQUIRES_OK(c, input_ptrs.Init());
   for (int i = 0; i < inputs_flat.size(); ++i) {
     input_ptrs.Set(i, inputs_flat[i]->data());
   }
   OP_REQUIRES_OK(c, input_ptrs.Finalize());

-  CudaDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
+  GpuDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
   OP_REQUIRES_OK(c, output_scan.Init());
   IntType scan = 0;
   output_scan.Set(0, scan);
diff --git a/tensorflow/core/kernels/concat_lib_gpu.h b/tensorflow/core/kernels/concat_lib_gpu.h
index 3fcecd754fe..2db66a7c5a8 100644
--- a/tensorflow/core/kernels/concat_lib_gpu.h
+++ b/tensorflow/core/kernels/concat_lib_gpu.h
@@ -38,8 +38,8 @@ void ConcatGPUSlice(

 template <typename T, typename IntType>
 void ConcatGPUImpl(const Eigen::GpuDevice& d,
-                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
-                   const CudaDeviceArrayStruct<IntType>& ptr_offsets,
+                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
+                   const GpuDeviceArrayStruct<IntType>& ptr_offsets,
                    bool same_size, int slice_size,
                    typename TTypes<T, 2>::Matrix* output);
@@ -57,13 +57,13 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
           typename TTypes<T, 2>::Matrix* output);                          \
   extern template void ConcatGPUImpl<T, int32>(                            \
       const Eigen::GpuDevice& d,                                           \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs,                   \
-      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,    \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs,                    \
+      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,     \
       int split_size, typename TTypes<T, 2>::Matrix* output);              \
   extern template void ConcatGPUImpl<T, int64>(                            \
       const Eigen::GpuDevice& d,                                           \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs,                   \
-      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,    \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs,                    \
+      const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,     \
       int split_size, typename TTypes<T, 2>::Matrix* output);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER);
diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
index 36f2a8ec3b7..e5a00c25cdd 100644
--- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
@@ -35,9 +35,9 @@ namespace {

 template <typename T, typename IntType>
 __global__ void concat_fixed_kernel(
-    CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
+    GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
     int total_rows, int total_cols, T* output) {
-  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
+  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

   for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
@@ -59,11 +59,11 @@ __global__ void concat_fixed_kernel(

 // cannot be in anonymous namespace due to extern shared memory
 template <typename T, typename IntType, bool useSmem>
 __global__ void concat_variable_kernel(
-    CudaDeviceArrayStruct<const T*> input_ptr_data,
-    CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
+    GpuDeviceArrayStruct<const T*> input_ptr_data,
+    GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
     IntType total_cols, T* output) {
-  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
-  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
+  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
+  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
   // do upper_bound on col to find which pointer we should be using
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -136,8 +136,8 @@ void ConcatGPUSlice(

 template <typename T, typename IntType>
 void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
-                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
-                   const CudaDeviceArrayStruct<IntType>& output_scan,
+                   const GpuDeviceArrayStruct<const T*>& input_ptrs,
+                   const GpuDeviceArrayStruct<IntType>& output_scan,
                    bool fixed_size, int split_size,
                    typename TTypes<T, 2>::Matrix* output) {
   auto config = GetCuda2DLaunchConfig(output->dimension(1),
@@ -185,18 +185,18 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
           inputs_flat,                                                      \
       typename TTypes<T, 2>::Matrix* output);

-#define REGISTER_GPU32(T)                                                   \
-  template void ConcatGPUImpl<T, int32>(                                    \
-      const Eigen::GpuDevice& d,                                            \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs,                    \
-      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,     \
+#define REGISTER_GPU32(T)                                                   \
+  template void ConcatGPUImpl<T, int32>(                                    \
+      const Eigen::GpuDevice& d,                                            \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs,                     \
+      const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,      \
       int split_size, typename TTypes<T, 2>::Matrix* output);

-#define REGISTER_GPU64(T)                                                   \
-  template void ConcatGPUImpl<T, int64>(                                    \
-      const Eigen::GpuDevice& d,                                            \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs,                    \
-      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,     \
+#define REGISTER_GPU64(T)                                                   \
+  template void ConcatGPUImpl<T, int64>(                                    \
+      const Eigen::GpuDevice& d,                                            \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs,                     \
+      const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,      \
       int split_size, typename TTypes<T, 2>::Matrix* output);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index 8a5f0b570fb..471bd7fbb1c 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -138,15 +138,15 @@ class DynamicStitchOpImplBase : public OpKernel {
 template <typename T>
 void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                           const int32 slice_size, const int32 first_dim_size,
-                          const CudaDeviceArrayStruct<int32>& input_indices,
-                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          const GpuDeviceArrayStruct<int32>& input_indices,
+                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                           T* output);
 #define REGISTER_GPU(T)                                           \
   extern template void DynamicStitchGPUImpl<T>(                   \
       const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
       const int32 first_dim_size,                                 \
-      const CudaDeviceArrayStruct<int32>& input_indices,          \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+      const GpuDeviceArrayStruct<int32>& input_indices,           \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
@@ -185,8 +185,8 @@ class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
     // implicitly using atomics to make sure the last index is the final
     // write.
     const int slice_size = merged->flat_outer_dims<T>().dimension(1);
-    CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
-    CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+    GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+    GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
     OP_REQUIRES_OK(c, indices_flat.Init());
     OP_REQUIRES_OK(c, data_flat.Init());
     // initialize the indices_flat (-1 represents missing indices)
diff --git a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
index 0c60a405a16..dd8b3489614 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
@@ -31,11 +31,11 @@ namespace {
 template <typename T>
 __global__ void DynamicStitchKernel(const int32 slice_size,
                                     const int32 output_size,
-                                    CudaDeviceArrayStruct<int32> input_indices,
-                                    CudaDeviceArrayStruct<const T*> input_ptrs,
+                                    GpuDeviceArrayStruct<int32> input_indices,
+                                    GpuDeviceArrayStruct<const T*> input_ptrs,
                                     T* output) {
-  int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
-  const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
+  int32* data_indices = GetGpuDeviceArrayOnDevice(&input_indices);
+  const T** data_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs);
   CUDA_1D_KERNEL_LOOP(output_index, output_size) {
     const int32 slice_id = output_index / slice_size;
     const int32 slice_offset = output_index % slice_size;
@@ -51,8 +51,8 @@ __global__ void DynamicStitchKernel(const int32 slice_size,
 template <typename T>
 void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                           const int32 slice_size, const int32 first_dim_size,
-                          const CudaDeviceArrayStruct<int32>& input_indices,
-                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          const GpuDeviceArrayStruct<int32>& input_indices,
+                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                           T* output) {
   const int32 output_size = first_dim_size * slice_size;
   auto config = GetCudaLaunchConfig(output_size, gpu_device);
@@ -67,8 +67,8 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
   template void DynamicStitchGPUImpl<T>(                          \
       const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
       const int32 first_dim_size,                                 \
-      const CudaDeviceArrayStruct<int32>& input_indices,          \
-      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+      const GpuDeviceArrayStruct<int32>& input_indices,           \
+      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/gpu_device_array.h b/tensorflow/core/kernels/gpu_device_array.h
index 62e39b6e75c..3961cee043b 100644
--- a/tensorflow/core/kernels/gpu_device_array.h
+++ b/tensorflow/core/kernels/gpu_device_array.h
@@ -15,7 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -24,11 +24,11 @@ limitations under the License.
 namespace tensorflow {

 // Create an array of value on the host, to be sent to kernel using
-// CudaDeviceArrayStruct.
+// GpuDeviceArrayStruct.
 //
 // Usage:
 //   int size = ...;
 //   CudaDeviceArrayOnHost ptrs(context, size);
 //   OP_REQUIRES_OK(ptrs.Init());
 //   for (int i = 0; i < size; ++i) {
 //     ptrs.Set(i, ...);
@@ -38,9 +38,9 @@ namespace tensorflow {
 //
 // ValueType must be memcopyable.
 template <typename ValueType, int MaxInlineValues = 8>
-class CudaDeviceArrayOnHost {
+class GpuDeviceArrayOnHost {
  public:
-  CudaDeviceArrayOnHost(OpKernelContext* context, int32 size)
+  GpuDeviceArrayOnHost(OpKernelContext* context, int32 size)
       : context_(context),
         total_bytes_(static_cast<int64>(size) * sizeof(ValueType)) {
     data_.size = size;
@@ -93,7 +93,7 @@ class CudaDeviceArrayOnHost {
     return Status::OK();
   }

-  const CudaDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
+  const GpuDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
     // Ensure Finalize is called.
     DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized());
     return data_;
@@ -105,16 +105,16 @@ class CudaDeviceArrayOnHost {
   OpKernelContext* const context_;
   const int64 total_bytes_;  // total size of all pointers.
   ValueType* values_ = nullptr;
-  CudaDeviceArrayStruct<ValueType, MaxInlineValues> data_;
+  GpuDeviceArrayStruct<ValueType, MaxInlineValues> data_;

   Tensor out_of_line_values_on_host_;
   Tensor out_of_line_values_on_gpu_;

-  TF_DISALLOW_COPY_AND_ASSIGN(CudaDeviceArrayOnHost);
+  TF_DISALLOW_COPY_AND_ASSIGN(GpuDeviceArrayOnHost);
 };

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
diff --git a/tensorflow/core/kernels/gpu_device_array_gpu.h b/tensorflow/core/kernels/gpu_device_array_gpu.h
index 64fa3cb806b..ca2051c70db 100644
--- a/tensorflow/core/kernels/gpu_device_array_gpu.h
+++ b/tensorflow/core/kernels/gpu_device_array_gpu.h
@@ -18,15 +18,15 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 namespace tensorflow {

-static constexpr int kMaxInlineCudaPointers = 8;
-// To decode on the device side, use GetCudaDeviceArrayOnDevice.
-// To encode on the host side, use CudaDeviceArrayOnHost.
+static constexpr int kMaxInlineGpuPointers = 8;
+// To decode on the device side, use GetGpuDeviceArrayOnDevice.
+// To encode on the host side, use GpuDeviceArrayOnHost.
 template <typename ValueType, int MaxInlineValues = 8>
-struct CudaDeviceArrayStruct {
+struct GpuDeviceArrayStruct {
   int32 size;
   // used if size <= MaxInlineValues;
   ValueType inline_values[MaxInlineValues];
   ValueType* out_of_line_values = nullptr;  // used if size > MaxInlineValues;
 };

@@ -34,8 +34,8 @@ struct CudaDeviceArrayStruct {
 template <typename ValueType, int MaxInlineValues = 8>
-EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
-    CudaDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice(
+    GpuDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
   if (data->size <= MaxInlineValues) {
     return data->inline_values;
   } else {
     return data->out_of_line_values;
   }
@@ -45,6 +45,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 3df7306c31b..0f0ea7c7c73 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -77,9 +77,9 @@ namespace {
 template <typename T>
 __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
                               int32 split_dim_size, int32 suffix_dim_size,
-                              CudaDeviceArrayStruct<T*> output_ptr_data) {
+                              GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);

   eigen_assert(blockDim.y == 1);
   eigen_assert(blockDim.z == 1);
@@ -114,11 +114,11 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
 // is reversed
 template <typename T, typename IntType, bool useSmem>
 __global__ void split_v_kernel(const T* input_ptr,
-                               CudaDeviceArrayStruct<IntType> output_scan,
+                               GpuDeviceArrayStruct<IntType> output_scan,
                                IntType total_rows, IntType total_cols,
-                               CudaDeviceArrayStruct<T*> output_ptr_data) {
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
-  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
+                               GpuDeviceArrayStruct<T*> output_ptr_data) {
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
+  IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
   // do upper_bound on col to find which pointer we should be using
   IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -170,11 +170,11 @@ __global__ void split_v_kernel(const T* input_ptr,
 // different from the original split implementation due to 2D vs 3D
 // dimensions. This version is likely faster due to less integer math.
 template <typename T>
-__global__ void SplitVOpKernel_fixed(
-    const T* input, int32 prefix_dim_size, int32 suffix_dim_size,
-    CudaDeviceArrayStruct<T*> output_ptr_data) {
+__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
+                                     int32 suffix_dim_size,
+                                     GpuDeviceArrayStruct<T*> output_ptr_data) {
   const int32 num_split = output_ptr_data.size;
-  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
+  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);

   eigen_assert(blockDim.y == 1);
   eigen_assert(blockDim.z == 1);
@@ -195,10 +195,10 @@
 }

 template <typename T>
-void SplitOpGPULaunch<T>::Run(
-    const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
-    int32 split_dim_size, int32 suffix_dim_size,
-    const CudaDeviceArrayStruct<T*>& output_ptr_data) {
+void SplitOpGPULaunch<T>::Run(const Eigen::GpuDevice& d, const T* input,
+                              int32 prefix_dim_size, int32 split_dim_size,
+                              int32 suffix_dim_size,
+                              const GpuDeviceArrayStruct<T*>& output_ptr_data) {
   CudaLaunchConfig config = GetCudaLaunchConfig(
       prefix_dim_size * split_dim_size * suffix_dim_size, d);

@@ -212,8 +212,8 @@
 template <typename T, typename IntType>
 void SplitVOpGPULaunch<T, IntType>::Run(
     const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
     int total_rows, int total_cols,
-    const CudaDeviceArrayStruct<IntType>& output_scan,
-    const CudaDeviceArrayStruct<T*>& output_ptr_data) {
+    const GpuDeviceArrayStruct<IntType>& output_scan,
+    const GpuDeviceArrayStruct<T*>& output_ptr_data) {
   if (fixed_size) {
     CudaLaunchConfig config =
         GetCudaLaunchConfig(total_rows * total_cols, gpu_device);
diff --git a/tensorflow/core/kernels/split_lib_gpu.h b/tensorflow/core/kernels/split_lib_gpu.h
index ff76d072319..20feb7df143 100644
--- a/tensorflow/core/kernels/split_lib_gpu.h
+++ b/tensorflow/core/kernels/split_lib_gpu.h
@@ -33,15 +33,15 @@ template <typename T>
 struct SplitOpGPULaunch {
   void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
            int32 split_dim_size, int32 suffix_dim_size,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+           const GpuDeviceArrayStruct<T*>& output_ptr_data);
 };

 template <typename T, typename IntType>
 struct SplitVOpGPULaunch {
   void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
            int total_cols, int total_rows,
-           const CudaDeviceArrayStruct<IntType>& output_scan,
-           const CudaDeviceArrayStruct<T*>& output_ptr_data);
+           const GpuDeviceArrayStruct<IntType>& output_scan,
+           const GpuDeviceArrayStruct<T*>& output_ptr_data);
 };

 // Explicit instantiations in split_lib_gpu.cu.cc.
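
The split_op.cc and split_v_op.cc hunks that follow are pure call-site renames: each one builds its pointer array on the host with GpuDeviceArrayOnHost before launching one of the kernels above. As a reading aid, here is a minimal sketch of that Init/Set/Finalize protocol; the helper name MakeOutputPtrArray and the `outputs` vector are hypothetical stand-ins, not names from this patch or the TensorFlow tree.

// Sketch only -- not part of the patch. MakeOutputPtrArray and `outputs`
// are illustrative; the Init/Set/Finalize sequence follows the usage
// comment in gpu_device_array.h.
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

template <typename T>
Status MakeOutputPtrArray(OpKernelContext* context,
                          const std::vector<Tensor*>& outputs,
                          GpuDeviceArrayOnHost<T*>* ptrs) {
  // Init() sets up the host-side representation (and staging storage for
  // the out-of-line case).
  TF_RETURN_IF_ERROR(ptrs->Init());
  for (int i = 0; i < static_cast<int>(outputs.size()); ++i) {
    // Set() records the raw device pointer of output i.
    ptrs->Set(i, outputs[i]->flat<T>().data());
  }
  // Finalize() publishes the array to the device; values that did not fit
  // in the inline slots are copied into a GPU-resident tensor.
  return ptrs->Finalize();
}

}  // namespace tensorflow

The caller constructs GpuDeviceArrayOnHost<T*> ptrs(context, num_split) and hands ptrs.data() to the kernel launch, which is exactly the shape of the SplitOpGPU and SplitVOpGPU call sites below.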
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 522419c294c..a419eedb398 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -302,7 +302,7 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     TensorShape output_shape(input_shape);
     output_shape.set_dim(split_dim, split_dim_output_size);

-    CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
+    GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
     OP_REQUIRES_OK(context, ptrs.Init());

     for (int i = 0; i < num_split; ++i) {
diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc
index 01cd8d81e9b..8e53089af0d 100644
--- a/tensorflow/core/kernels/split_v_op.cc
+++ b/tensorflow/core/kernels/split_v_op.cc
@@ -366,10 +366,10 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {

     // reshape to 2D
     if (num_split > 16) {
-      CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
+      GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
       OP_REQUIRES_OK(context, ptrs.Init());

-      CudaDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
+      GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
       OP_REQUIRES_OK(context, offsets.Init());

       Tlen offset = 0;
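
On the device side, every kernel in this patch decodes the struct the same way: it receives the GpuDeviceArrayStruct by value in the kernel arguments and unwraps it with GetGpuDeviceArrayOnDevice. A hedged sketch of that decode, with a made-up ScaleOutputsKernel standing in for the real kernels; GpuDeviceArrayStruct and GetGpuDeviceArrayOnDevice are the names this patch introduces in gpu_device_array_gpu.h.

// Sketch only -- not part of the patch. ScaleOutputsKernel is illustrative.
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"

namespace tensorflow {

template <typename T>
__global__ void ScaleOutputsKernel(GpuDeviceArrayStruct<T*> output_ptr_data,
                                   int rows_per_output, T scale) {
  // Resolves to inline_values when size <= MaxInlineValues (8 by default),
  // otherwise to the out-of-line device buffer that
  // GpuDeviceArrayOnHost::Finalize() populated.
  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
  const int num_outputs = output_ptr_data.size;
  // Blocks stride over outputs; threads stride over rows within an output.
  for (int i = blockIdx.x; i < num_outputs; i += gridDim.x) {
    for (int r = threadIdx.x; r < rows_per_output; r += blockDim.x) {
      output_ptrs[i][r] *= scale;
    }
  }
}

}  // namespace tensorflow

Passing the struct by value is the point of the design: for a fan-out of up to kMaxInlineGpuPointers (8), the pointers ride along in the kernel-argument buffer and no separate host-to-device copy is needed; only larger fan-outs pay for the out-of-line transfer.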