Merge pull request #26678 from ROCmSoftwarePlatform:google_upstream_gpu_device_array
PiperOrigin-RevId: 238488799
This commit is contained in:
commit
dafd11de8d
@ -36,8 +36,8 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template <typename T, bool useSharedMem>
|
||||
__global__ void BucketizeCustomKernel(
|
||||
const int32 size_in, const T* in, const int32 size_boundaries,
|
||||
CudaDeviceArrayStruct<float> boundaries_array, int32* out) {
|
||||
const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array);
|
||||
GpuDeviceArrayStruct<float> boundaries_array, int32* out) {
|
||||
const float* boundaries = GetGpuDeviceArrayOnDevice(&boundaries_array);
|
||||
|
||||
extern __shared__ __align__(sizeof(float)) unsigned char shared_mem[];
|
||||
float* shared_mem_boundaries = reinterpret_cast<float*>(shared_mem);
|
||||
@ -85,8 +85,8 @@ struct BucketizeFunctor<GPUDevice, T> {
|
||||
typename TTypes<int32, 1>::Tensor& output) {
|
||||
const GPUDevice& d = context->eigen_device<GPUDevice>();
|
||||
|
||||
CudaDeviceArrayOnHost<float> boundaries_array(context,
|
||||
boundaries_vector.size());
|
||||
GpuDeviceArrayOnHost<float> boundaries_array(context,
|
||||
boundaries_vector.size());
|
||||
TF_RETURN_IF_ERROR(boundaries_array.Init());
|
||||
for (int i = 0; i < boundaries_vector.size(); ++i) {
|
||||
boundaries_array.Set(i, boundaries_vector[i]);
|
||||
|
@ -38,14 +38,14 @@ void ConcatGPUCall(
|
||||
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
|
||||
inputs_flat,
|
||||
typename TTypes<T, 2>::Tensor* output_flat) {
|
||||
CudaDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
|
||||
GpuDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
|
||||
OP_REQUIRES_OK(c, input_ptrs.Init());
|
||||
for (int i = 0; i < inputs_flat.size(); ++i) {
|
||||
input_ptrs.Set(i, inputs_flat[i]->data());
|
||||
}
|
||||
OP_REQUIRES_OK(c, input_ptrs.Finalize());
|
||||
|
||||
CudaDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
|
||||
GpuDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
|
||||
OP_REQUIRES_OK(c, output_scan.Init());
|
||||
IntType scan = 0;
|
||||
output_scan.Set(0, scan);
|
||||
|
@ -38,8 +38,8 @@ void ConcatGPUSlice(
|
||||
|
||||
template <typename T, typename IntType>
|
||||
void ConcatGPUImpl(const Eigen::GpuDevice& d,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const CudaDeviceArrayStruct<IntType>& ptr_offsets,
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const GpuDeviceArrayStruct<IntType>& ptr_offsets,
|
||||
bool same_size, int slice_size,
|
||||
typename TTypes<T, 2>::Matrix* output);
|
||||
|
||||
@ -57,13 +57,13 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
|
||||
typename TTypes<T, 2>::Matrix* output); \
|
||||
extern template void ConcatGPUImpl<T, int32>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
|
||||
int split_size, typename TTypes<T, 2>::Matrix* output); \
|
||||
extern template void ConcatGPUImpl<T, int64>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
|
||||
int split_size, typename TTypes<T, 2>::Matrix* output);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
|
@ -35,9 +35,9 @@ namespace {
|
||||
|
||||
template <typename T, typename IntType>
|
||||
__global__ void concat_fixed_kernel(
|
||||
CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
|
||||
GpuDeviceArrayStruct<const T*> input_ptr_data, int split_size,
|
||||
int total_rows, int total_cols, T* output) {
|
||||
const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
|
||||
const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
|
||||
IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
|
||||
@ -59,11 +59,11 @@ __global__ void concat_fixed_kernel(
|
||||
// cannot be in anonymous namespace due to extern shared memory
|
||||
template <typename T, typename IntType, bool useSmem>
|
||||
__global__ void concat_variable_kernel(
|
||||
CudaDeviceArrayStruct<const T*> input_ptr_data,
|
||||
CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
|
||||
GpuDeviceArrayStruct<const T*> input_ptr_data,
|
||||
GpuDeviceArrayStruct<IntType> output_scan, IntType total_rows,
|
||||
IntType total_cols, T* output) {
|
||||
const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
|
||||
IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
|
||||
const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
|
||||
IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
|
||||
|
||||
// do upper_bound on col to find which pointer we should be using
|
||||
IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
@ -136,8 +136,8 @@ void ConcatGPUSlice(
|
||||
|
||||
template <typename T, typename IntType>
|
||||
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const CudaDeviceArrayStruct<IntType>& output_scan,
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const GpuDeviceArrayStruct<IntType>& output_scan,
|
||||
bool fixed_size, int split_size,
|
||||
typename TTypes<T, 2>::Matrix* output) {
|
||||
auto config = GetCuda2DLaunchConfig(output->dimension(1),
|
||||
@ -185,18 +185,18 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
inputs_flat, \
|
||||
typename TTypes<T, 2>::Matrix* output);
|
||||
|
||||
#define REGISTER_GPU32(T) \
|
||||
template void ConcatGPUImpl<T, int32>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
|
||||
#define REGISTER_GPU32(T) \
|
||||
template void ConcatGPUImpl<T, int32>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const GpuDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
|
||||
int split_size, typename TTypes<T, 2>::Matrix* output);
|
||||
|
||||
#define REGISTER_GPU64(T) \
|
||||
template void ConcatGPUImpl<T, int64>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
|
||||
#define REGISTER_GPU64(T) \
|
||||
template void ConcatGPUImpl<T, int64>( \
|
||||
const Eigen::GpuDevice& d, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, \
|
||||
const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
|
||||
int split_size, typename TTypes<T, 2>::Matrix* output);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
|
||||
|
@ -138,15 +138,15 @@ class DynamicStitchOpImplBase : public OpKernel {
|
||||
template <typename T>
|
||||
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
const int32 slice_size, const int32 first_dim_size,
|
||||
const CudaDeviceArrayStruct<int>& input_indices,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const GpuDeviceArrayStruct<int>& input_indices,
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs,
|
||||
T* output);
|
||||
#define REGISTER_GPU(T) \
|
||||
extern template void DynamicStitchGPUImpl( \
|
||||
const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
|
||||
const int32 first_dim_size, \
|
||||
const CudaDeviceArrayStruct<int32>& input_indices, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
|
||||
const GpuDeviceArrayStruct<int32>& input_indices, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
TF_CALL_complex128(REGISTER_GPU);
|
||||
@ -185,8 +185,8 @@ class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
|
||||
// implicitly using atomics to make sure the last index is the final
|
||||
// write.
|
||||
const int slice_size = merged->flat_outer_dims<T>().dimension(1);
|
||||
CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
|
||||
CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
|
||||
GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
|
||||
GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
|
||||
OP_REQUIRES_OK(c, indices_flat.Init());
|
||||
OP_REQUIRES_OK(c, data_flat.Init());
|
||||
// initialize the indices_flat (-1 represents missing indices)
|
||||
|
@ -31,11 +31,11 @@ namespace {
|
||||
template <typename T>
|
||||
__global__ void DynamicStitchKernel(const int32 slice_size,
|
||||
const int32 output_size,
|
||||
CudaDeviceArrayStruct<int32> input_indices,
|
||||
CudaDeviceArrayStruct<const T*> input_ptrs,
|
||||
GpuDeviceArrayStruct<int32> input_indices,
|
||||
GpuDeviceArrayStruct<const T*> input_ptrs,
|
||||
T* output) {
|
||||
int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
|
||||
const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
|
||||
int32* data_indices = GetGpuDeviceArrayOnDevice(&input_indices);
|
||||
const T** data_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs);
|
||||
CUDA_1D_KERNEL_LOOP(output_index, output_size) {
|
||||
const int32 slice_id = output_index / slice_size;
|
||||
const int32 slice_offset = output_index % slice_size;
|
||||
@ -51,8 +51,8 @@ __global__ void DynamicStitchKernel(const int32 slice_size,
|
||||
template <typename T>
|
||||
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
const int32 slice_size, const int32 first_dim_size,
|
||||
const CudaDeviceArrayStruct<int>& input_indices,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
const GpuDeviceArrayStruct<int>& input_indices,
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs,
|
||||
T* output) {
|
||||
const int32 output_size = first_dim_size * slice_size;
|
||||
auto config = GetCudaLaunchConfig(output_size, gpu_device);
|
||||
@ -67,8 +67,8 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
template void DynamicStitchGPUImpl( \
|
||||
const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
|
||||
const int32 first_dim_size, \
|
||||
const CudaDeviceArrayStruct<int32>& input_indices, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
|
||||
const GpuDeviceArrayStruct<int32>& input_indices, \
|
||||
const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -24,11 +24,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// Create an array of value on the host, to be sent to kernel using
|
||||
// CudaDeviceArrayStruct.
|
||||
// GpuDeviceArrayStruct.
|
||||
//
|
||||
// Usage:
|
||||
// int size = ...;
|
||||
// CudaDeviceArrayOnHost ptrs(context, size);
|
||||
// GpuDeviceArrayOnHost ptrs(context, size);
|
||||
// OP_REQUIRES_OK(ptrs.Init());
|
||||
// for (int i = 0; i < size; ++i) {
|
||||
// ptrs.Set(i, ...);
|
||||
@ -38,9 +38,9 @@ namespace tensorflow {
|
||||
//
|
||||
// ValueType must be memcopyable.
|
||||
template <typename ValueType, int MaxInlineValues = 8>
|
||||
class CudaDeviceArrayOnHost {
|
||||
class GpuDeviceArrayOnHost {
|
||||
public:
|
||||
CudaDeviceArrayOnHost(OpKernelContext* context, int32 size)
|
||||
GpuDeviceArrayOnHost(OpKernelContext* context, int32 size)
|
||||
: context_(context),
|
||||
total_bytes_(static_cast<int64>(size) * sizeof(ValueType)) {
|
||||
data_.size = size;
|
||||
@ -93,7 +93,7 @@ class CudaDeviceArrayOnHost {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const CudaDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
|
||||
const GpuDeviceArrayStruct<ValueType, MaxInlineValues>& data() const {
|
||||
// Ensure Finalize is called.
|
||||
DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized());
|
||||
return data_;
|
||||
@ -105,16 +105,16 @@ class CudaDeviceArrayOnHost {
|
||||
OpKernelContext* const context_;
|
||||
const int64 total_bytes_; // total size of all pointers.
|
||||
ValueType* values_ = nullptr;
|
||||
CudaDeviceArrayStruct<ValueType, MaxInlineValues> data_;
|
||||
GpuDeviceArrayStruct<ValueType, MaxInlineValues> data_;
|
||||
|
||||
Tensor out_of_line_values_on_host_;
|
||||
Tensor out_of_line_values_on_gpu_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaDeviceArrayOnHost);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuDeviceArrayOnHost);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
|
||||
|
@ -18,15 +18,15 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static constexpr int kMaxInlineCudaPointers = 8;
|
||||
// To decode on the device side, use GetCudaDeviceArrayOnDevice.
|
||||
// To encode on the host side, use CudaDeviceArrayOnHost.
|
||||
static constexpr int kMaxInlineGpuPointers = 8;
|
||||
// To decode on the device side, use GetGpuDeviceArrayOnDevice.
|
||||
// To encode on the host side, use GpuDeviceArrayOnHost.
|
||||
template <typename ValueType, int MaxInlineValues = 8>
|
||||
struct CudaDeviceArrayStruct {
|
||||
struct GpuDeviceArrayStruct {
|
||||
int32 size;
|
||||
// used if size <= MaxInlineValues;
|
||||
ValueType inline_values[MaxInlineValues];
|
||||
@ -34,8 +34,8 @@ struct CudaDeviceArrayStruct {
|
||||
};
|
||||
|
||||
template <typename ValueType, int MaxInlineValues = 8>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
|
||||
CudaDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice(
|
||||
GpuDeviceArrayStruct<ValueType, MaxInlineValues>* data) {
|
||||
if (data->size <= MaxInlineValues) {
|
||||
return data->inline_values;
|
||||
} else {
|
||||
@ -45,6 +45,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
|
||||
|
@ -77,9 +77,9 @@ namespace {
|
||||
template <typename T>
|
||||
__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
|
||||
int32 split_dim_size, int32 suffix_dim_size,
|
||||
CudaDeviceArrayStruct<T*> output_ptr_data) {
|
||||
GpuDeviceArrayStruct<T*> output_ptr_data) {
|
||||
const int32 num_split = output_ptr_data.size;
|
||||
T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
|
||||
T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
|
||||
|
||||
eigen_assert(blockDim.y == 1);
|
||||
eigen_assert(blockDim.z == 1);
|
||||
@ -114,11 +114,11 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
|
||||
// is reversed
|
||||
template <typename T, typename IntType, bool useSmem>
|
||||
__global__ void split_v_kernel(const T* input_ptr,
|
||||
CudaDeviceArrayStruct<IntType> output_scan,
|
||||
GpuDeviceArrayStruct<IntType> output_scan,
|
||||
IntType total_rows, IntType total_cols,
|
||||
CudaDeviceArrayStruct<T*> output_ptr_data) {
|
||||
T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
|
||||
IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);
|
||||
GpuDeviceArrayStruct<T*> output_ptr_data) {
|
||||
T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
|
||||
IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan);
|
||||
|
||||
// do upper_bound on col to find which pointer we should be using
|
||||
IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
@ -170,11 +170,11 @@ __global__ void split_v_kernel(const T* input_ptr,
|
||||
// different from the original split implementation due to 2D vs 3D
|
||||
// dimensions. This version is likely faster due to less integer math.
|
||||
template <typename T>
|
||||
__global__ void SplitVOpKernel_fixed(
|
||||
const T* input, int32 prefix_dim_size, int32 suffix_dim_size,
|
||||
CudaDeviceArrayStruct<T*> output_ptr_data) {
|
||||
__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
|
||||
int32 suffix_dim_size,
|
||||
GpuDeviceArrayStruct<T*> output_ptr_data) {
|
||||
const int32 num_split = output_ptr_data.size;
|
||||
T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
|
||||
T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
|
||||
|
||||
eigen_assert(blockDim.y == 1);
|
||||
eigen_assert(blockDim.z == 1);
|
||||
@ -195,10 +195,10 @@ __global__ void SplitVOpKernel_fixed(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitOpGPULaunch<T>::Run(
|
||||
const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
|
||||
int32 split_dim_size, int32 suffix_dim_size,
|
||||
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
|
||||
void SplitOpGPULaunch<T>::Run(const Eigen::GpuDevice& d, const T* input,
|
||||
int32 prefix_dim_size, int32 split_dim_size,
|
||||
int32 suffix_dim_size,
|
||||
const GpuDeviceArrayStruct<T*>& output_ptr_data) {
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(
|
||||
prefix_dim_size * split_dim_size * suffix_dim_size, d);
|
||||
|
||||
@ -212,8 +212,8 @@ template <typename T, typename IntType>
|
||||
void SplitVOpGPULaunch<T, IntType>::Run(
|
||||
const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr,
|
||||
int total_rows, int total_cols,
|
||||
const CudaDeviceArrayStruct<IntType>& output_scan,
|
||||
const CudaDeviceArrayStruct<T*>& output_ptr_data) {
|
||||
const GpuDeviceArrayStruct<IntType>& output_scan,
|
||||
const GpuDeviceArrayStruct<T*>& output_ptr_data) {
|
||||
if (fixed_size) {
|
||||
CudaLaunchConfig config =
|
||||
GetCudaLaunchConfig(total_rows * total_cols, gpu_device);
|
||||
|
@ -33,15 +33,15 @@ template <typename T>
|
||||
struct SplitOpGPULaunch {
|
||||
void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
|
||||
int32 split_dim_size, int32 suffix_dim_size,
|
||||
const CudaDeviceArrayStruct<T*>& output_ptr_data);
|
||||
const GpuDeviceArrayStruct<T*>& output_ptr_data);
|
||||
};
|
||||
|
||||
template <typename T, typename IntType>
|
||||
struct SplitVOpGPULaunch {
|
||||
void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
|
||||
int total_cols, int total_rows,
|
||||
const CudaDeviceArrayStruct<IntType>& output_scan,
|
||||
const CudaDeviceArrayStruct<T*>& output_ptr_data);
|
||||
const GpuDeviceArrayStruct<IntType>& output_scan,
|
||||
const GpuDeviceArrayStruct<T*>& output_ptr_data);
|
||||
};
|
||||
|
||||
// Explicit instantiations in split_lib_gpu.cu.cc.
|
||||
|
@ -302,7 +302,7 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
|
||||
TensorShape output_shape(input_shape);
|
||||
output_shape.set_dim(split_dim, split_dim_output_size);
|
||||
|
||||
CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
|
||||
GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
|
||||
OP_REQUIRES_OK(context, ptrs.Init());
|
||||
|
||||
for (int i = 0; i < num_split; ++i) {
|
||||
|
@ -366,10 +366,10 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
|
||||
// reshape to 2D
|
||||
|
||||
if (num_split > 16) {
|
||||
CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
|
||||
GpuDeviceArrayOnHost<T*> ptrs(context, num_split);
|
||||
OP_REQUIRES_OK(context, ptrs.Init());
|
||||
|
||||
CudaDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
|
||||
GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
|
||||
OP_REQUIRES_OK(context, offsets.Init());
|
||||
|
||||
Tlen offset = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user