Adding ROCm support for the CSR Sparse Matrix Ops

Deven Desai 2019-11-20 15:45:31 +00:00
parent 5ad7620d6f
commit 2e1cdaa4b6
19 changed files with 221 additions and 136 deletions
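
The port follows one pattern across all 19 files: each #if GOOGLE_CUDA guard around GPU code widens to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM, raw kernel launches and cub calls move to the portable Gpu* helpers, and the complex-typed kernel registrations stay behind a CUDA-only guard. A minimal sketch of the registration pattern that repeats below (illustrative only; REGISTER_GPU stands in for the per-file registration macros, and keeping complex64/complex128 CUDA-only is assumed to reflect missing complex coverage in the ROCm sparse libraries at the time):

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
// complex64/complex128 kernels remain CUDA-only for now (assumption: no
// hipSPARSE support for complex types when this was written).
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif  // GOOGLE_CUDA
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM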

View File

@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
"if_cuda_or_rocm",
"tf_cc_test",
"tf_kernel_library",
)
@ -77,7 +78,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:scatter_nd_op",
"//tensorflow/core/kernels:slice_op",
"//tensorflow/core/kernels:transpose_functor",
] + if_cuda([
] + if_cuda_or_rocm([
"//tensorflow/core/kernels:cuda_solvers",
"//tensorflow/core/kernels:cuda_sparse",
]),

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/fill_functor.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -233,8 +233,10 @@ class CSRAddOp : public OpKernel {
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU
@ -246,7 +248,7 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
#undef REGISTER
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <typename T>
struct CSRSparseMatrixAdd<GPUDevice, T>
@ -337,6 +339,6 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -92,12 +92,12 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix,
(CSRSparseMatrixUnaryHelper<CPUDevice, CSRSparseMatrixConjFunctor>));
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
(CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>));
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -220,19 +220,21 @@ REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_CPU
#undef REGISTER_GPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <>
@ -256,6 +258,6 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -205,18 +205,20 @@ class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel {
.HostMemory("dense_shape"), \
CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <>
@ -240,7 +242,7 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -32,13 +32,18 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
#endif
namespace tensorflow {
@ -138,7 +143,7 @@ REGISTER_CPU(complex128)
#undef REGISTER_CPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T>
class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@ -356,8 +361,10 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
REGISTER_GPU(GPU, float)
REGISTER_GPU(GPU, double)
#if GOOGLE_CUDA
REGISTER_GPU(GPU, complex64)
REGISTER_GPU(GPU, complex128)
#endif
namespace functor {
@ -391,7 +398,7 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GPU

View File

@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/cub/device/device_histogram.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/gpus/cuda/include/cusparse.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
@ -32,6 +36,12 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
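
// Illustrative sketch, not part of this commit: CUB and hipCUB expose
// DeviceHistogram::HistogramEven with the same signature, so the gpuprim
// alias above lets one call site serve both backends. The helper name and
// its arguments are placeholders; gpuSuccess is assumed to come from the
// gpu_kernel_helper.h include above (it is also used later in this file).
namespace {
// Returns true on success. With d_temp_storage == nullptr the call only
// reports the required scratch size, mirroring the two-phase call pattern
// in CalculateNNZPerBatchMatrixFromIndices below.
template <typename GpuStreamT>
inline bool CountNNZPerBatchSketch(const int* d_batch_of_each_nnz,
                                   int* d_nnz_per_batch, int batch_size,
                                   int total_nnz, void* d_temp_storage,
                                   size_t& temp_storage_bytes,
                                   GpuStreamT stream) {
  auto err = gpuprim::DeviceHistogram::HistogramEven(
      d_temp_storage, temp_storage_bytes, d_batch_of_each_nnz,
      d_nnz_per_batch, /*num_levels=*/batch_size + 1, /*lower_level=*/0,
      /*upper_level=*/batch_size, /*num_samples=*/total_nnz, stream);
  return err == gpuSuccess;
}
}  // namespace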
@ -65,9 +75,9 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DCHECK_EQ(indices.dimension(1), 3); // batch, row, col
const int rank = indices.dimension(1);
cub::CountingInputIterator<int> row_counter(0);
cub::TransformInputIterator<int, StridedDataReader,
cub::CountingInputIterator<int>>
gpuprim::CountingInputIterator<int> row_counter(0);
gpuprim::TransformInputIterator<int, StridedDataReader,
gpuprim::CountingInputIterator<int>>
indices_first_column(row_counter,
StridedDataReader(indices.data(), rank));
@ -76,7 +86,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DCHECK_NE(indices.data(), nullptr);
DCHECK_NE(nnz_per_batch.data(), nullptr);
auto first_success = cub::DeviceHistogram::HistogramEven(
auto first_success = gpuprim::DeviceHistogram::HistogramEven(
/*d_temp_storage*/ nullptr,
/*temp_storage_bytes&*/ temp_storage_bytes,
/*d_samples*/ indices_first_column,
@ -87,12 +97,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
/*num_samples*/ total_nnz,
/*stream*/ cu_stream);
if (first_success != cudaSuccess) {
if (first_success != gpuSuccess) {
return errors::Internal(
"SparseTensorToCSRSparseMatrix: Could not launch "
"cub::DeviceHistogram::HistogramEven "
"gpuprim::DeviceHistogram::HistogramEven "
"to calculate temp_storage_bytes, status: ",
cudaGetErrorString(first_success));
GpuGetErrorString(first_success));
}
Tensor temp_storage;
@ -100,7 +110,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
&temp_storage));
DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
auto second_success = cub::DeviceHistogram::HistogramEven(
auto second_success = gpuprim::DeviceHistogram::HistogramEven(
/*d_temp_storage*/ temp_storage.flat<int8>().data(),
/*temp_storage_bytes&*/ temp_storage_bytes,
/*d_samples*/ indices_first_column,
@ -111,12 +121,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
/*num_samples*/ total_nnz,
/*stream*/ cu_stream);
if (second_success != cudaSuccess) {
if (second_success != gpuSuccess) {
return errors::Internal(
"SparseTensorToCSRSparseMatrix: Could not launch "
"cub::DeviceHistogram::HistogramEven "
"gpuprim::DeviceHistogram::HistogramEven "
"to count nnz entries per batch. temp_storage_bytes: ",
temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
}
return Status::OK();
@ -128,11 +138,11 @@ template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
TTypes<int>::UnalignedVec coo_row_ind) {
GpuSparse cuda_sparse(c);
GpuSparse gpu_sparse(c);
const int nnz = coo_row_ind.size();
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
const int m = csr_row_ptr.size() - 1; // rows
return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}
template <int stride>
@ -140,7 +150,7 @@ __global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
int* coo_rows_out,
int* coo_cols_out, int size) {
const int offset = (stride == 3) ? 1 : 0;
CUDA_1D_KERNEL_LOOP(i, size) {
GPU_1D_KERNEL_LOOP(i, size) {
coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
}
@ -157,20 +167,22 @@ void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
const int size = coo_row_ind.dimension(0);
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
if (stride == 2) {
SparseTensorToCOOMatrixKernel<2>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
config.block_count, config.thread_per_block, 0,
d.stream(), indices.data(), coo_row_ind.data(),
coo_col_ind.data(), size));
} else {
SparseTensorToCOOMatrixKernel<3>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
config.block_count, config.thread_per_block, 0,
d.stream(), indices.data(), coo_row_ind.data(),
coo_col_ind.data(), size));
}
}
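
// Sketch of the launch-syntax change applied throughout this file; the
// kernel and wrapper below are hypothetical and only show the shape of the
// conversion. The CUDA-only form
//   SomeKernel<<<config.block_count, config.thread_per_block, shared,
//                d.stream()>>>(args...);
// becomes the portable GpuLaunchKernel helper, which also compiles under HIP:
__global__ void ScaleValuesKernelSketch(const float* in, float* out,
                                        float alpha, int size) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = alpha * ldg(in + i); }
}

void LaunchScaleValuesSketch(const GPUDevice& d, const float* in, float* out,
                             float alpha, int size) {
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  TF_CHECK_OK(GpuLaunchKernel(ScaleValuesKernelSketch, config.block_count,
                              config.thread_per_block, /*shared_memory=*/0,
                              d.stream(), in, out, alpha, size));
}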
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
const int* coo_cols,
int64* indices_out, int size) {
CUDA_1D_KERNEL_LOOP(i, size) {
GPU_1D_KERNEL_LOOP(i, size) {
indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
}
@ -203,7 +215,7 @@ __global__ void COOMatrixToSparseTensorKernel3D(
}
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, size) {
GPU_1D_KERNEL_LOOP(i, size) {
// TODO(ebrevdo): Consider special casing batch_size <= 3,
// alternatively doing linear instead of binary search. Requires
// some benchmarks.
@ -231,9 +243,10 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
DCHECK_EQ(size, indices.dimension(0));
if (ndims == 2) {
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
COOMatrixToSparseTensorKernel2D<<<config.block_count,
config.thread_per_block, 0, d.stream()>>>(
coo_row_ind.data(), coo_col_ind.data(), indices.data(), size);
TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
config.block_count, config.thread_per_block, 0,
d.stream(), coo_row_ind.data(),
coo_col_ind.data(), indices.data(), size));
return Status::OK();
} else {
const int batch_size = host_dense_shape(0);
@ -246,11 +259,11 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
// shared memory stores the batch pointers.
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
COOMatrixToSparseTensorKernel3D<<<config.block_count,
config.thread_per_block,
shared_memory_size, d.stream()>>>(
coo_row_ind.data(), coo_col_ind.data(), indices.data(),
batch_ptr_copy.data(), batch_size, size);
TF_CHECK_OK(
GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
config.thread_per_block, shared_memory_size, d.stream(),
coo_row_ind.data(), coo_col_ind.data(), indices.data(),
batch_ptr_copy.data(), batch_size, size));
return Status::OK();
}
}
@ -274,7 +287,7 @@ __global__ void CSRSparseMatrixBatchMulVecKernel3D(
}
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, total_nnz) {
GPU_1D_KERNEL_LOOP(i, total_nnz) {
const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
c_values[i] = ldg(a_values + i) * local_batch_values[b];
}
@ -316,10 +329,10 @@ Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
const size_t shared_memory_size =
(sizeof(int) * (batch_size + 1) // local batch_pointers.
+ sizeof(T) * batch_size); // local copy of b.
CSRSparseMatrixBatchMulVecKernel3D<T>
<<<config.block_count, config.thread_per_block, shared_memory_size,
d.stream()>>>(a_values.data(), b.data(), c_values.data(),
batch_ptr_copy.data(), batch_size, total_nnz);
TF_CHECK_OK(GpuLaunchKernel(
CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));
return Status::OK();
}
@ -374,7 +387,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
// algorithm to distribute the work in case the row sizes are
// uneven:
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
CUDA_1D_KERNEL_LOOP(row, rows) {
GPU_1D_KERNEL_LOOP(row, rows) {
CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
softmax);
}
@ -382,7 +395,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
for (int i = threadIdx.x; i < length; i += blockDim.x) {
local_ptr[i] = cuda_ptr[i];
@ -404,7 +417,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel3D(
CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
batch_size + 1);
CUDA_1D_KERNEL_LOOP(i, size) {
GPU_1D_KERNEL_LOOP(i, size) {
const int batch = i / rows;
const int row = i % rows;
const int batch_offset = local_batch_ptr[batch];
@ -431,10 +444,10 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
const int rows = host_dense_shape(0);
DCHECK_EQ(rows, row_ptr.size() - 1);
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
CSRSparseMatrixSoftmaxKernel2D<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
rows /*size*/, row_ptr.data(), logits_values.data(),
softmax_values.data());
TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
config.block_count, config.thread_per_block, 0,
d.stream(), rows /*size*/, row_ptr.data(),
logits_values.data(), softmax_values.data()));
} else {
const int batch_size = host_dense_shape(0);
const int rows = host_dense_shape(1);
@ -452,10 +465,11 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
// shared memory stores the batch pointers.
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
CSRSparseMatrixSoftmaxKernel3D<T>
<<<config.block_count, config.thread_per_block, shared_memory_size,
d.stream()>>>(size, rows, batch_ptr_copy.data(), row_ptr.data(),
logits_values.data(), softmax_values.data());
TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
config.block_count, config.thread_per_block,
shared_memory_size, d.stream(), size, rows,
batch_ptr_copy.data(), row_ptr.data(),
logits_values.data(), softmax_values.data()));
}
return Status::OK();
@ -549,7 +563,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel2D(
// algorithm to distribute the work in case the row sizes are
// uneven:
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
CUDA_1D_KERNEL_LOOP(row, rows) {
GPU_1D_KERNEL_LOOP(row, rows) {
CalculateRowSoftmaxGrad(
ldg(softmax_row_ptr + row) /*softmax_begin*/,
ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
@ -579,7 +593,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel3D(
#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];
CUDA_1D_KERNEL_LOOP(i, size) {
GPU_1D_KERNEL_LOOP(i, size) {
const int batch = i / rows;
const int row = i % rows;
const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
@ -625,12 +639,12 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
DCHECK_EQ(rows + 1, softmax_row_ptr.size());
DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
CSRSparseMatrixSoftmaxGradKernel2D<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(),
softmax_values.data(), grad_softmax_row_ptr.data(),
grad_softmax_col_ind.data(), grad_softmax_values.data(),
gradient_values.data());
TF_CHECK_OK(GpuLaunchKernel(
CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
config.thread_per_block, 0, d.stream(), rows /*size*/,
softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
grad_softmax_values.data(), gradient_values.data()));
} else {
const int batch_size = host_dense_shape(0);
const int rows = host_dense_shape(1);
@ -656,13 +670,13 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
// shared memory stores two copies of batch pointers: one for the
// softmax CSR matrix, one for the grad_softmax CSR matrix.
const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
CSRSparseMatrixSoftmaxGradKernel3D<T>
<<<config.block_count, config.thread_per_block, shared_memory_size,
d.stream()>>>(size, rows, softmax_and_grad_batch_ptr_copy.data(),
softmax_row_ptr.data(), softmax_col_ind.data(),
softmax_values.data(), grad_softmax_row_ptr.data(),
grad_softmax_col_ind.data(),
grad_softmax_values.data(), gradient_values.data());
TF_CHECK_OK(GpuLaunchKernel(
CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
config.thread_per_block, shared_memory_size, d.stream(), size, rows,
softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
softmax_col_ind.data(), softmax_values.data(),
grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
grad_softmax_values.data(), gradient_values.data()));
}
return Status::OK();
@ -687,4 +701,4 @@ DEFINE_SOFTMAX_GRAD_GPU(double);
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -694,7 +694,7 @@ REGISTER_CPU(complex128)
#undef REGISTER_CPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \
@ -703,14 +703,16 @@ REGISTER_CPU(complex128)
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
@ -741,11 +743,16 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
// transA must be non-transpose if transB is transpose (cusparse
// limitation).
#if GOOGLE_CUDA
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif
// transB: b is row-major, and cusparse requires col-major b (or
// equivalently transB == transpose). this version is actually more
// efficient.
#if GOOGLE_CUDA
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
@ -754,6 +761,16 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
const int k = b.dimension(0);
@ -816,11 +833,19 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
const T beta = 0;
gpusparseMatDescr_t descrA;
#if GOOGLE_CUDA
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int m = a.dense_shape_host(0);
const int n = a.dense_shape_host(1);
@ -841,6 +866,6 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
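
For the sparse-BLAS calls above, the gpusparse* names (gpusparseOperation_t, gpusparseMatDescr_t, TF_RETURN_IF_GPUSPARSE_ERROR) are backend-neutral wrappers, presumably supplied by the cuda_sparse.h header this file already includes, so only the enum constants and the descriptor-creation calls need per-backend guards. The transB choice also follows from the comment above: b is stored row-major, and a row-major buffer read column-major (as cuSPARSE/hipSPARSE expect) is exactly the transpose, so requesting op(B) = transpose recovers the intended matrix without an extra copy. A condensed, out-of-context sketch of the selection pattern (the enclosing function and the actual csrmm call are omitted):

#if GOOGLE_CUDA
  const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
  gpusparseMatDescr_t descrA;
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
  const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
  gpusparseMatDescr_t descrA;
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif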

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -101,22 +101,24 @@ class CSRMulOp : public OpKernel {
Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
CSRMulOp<DEV##Device, T>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) REGISTER(GPU, T)
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
@ -159,13 +161,15 @@ class CSRSparseMatrixMulScalar<GPUDevice, T> {
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if GOOGLE_CUDA
DECLARE_GPU_SPEC(std::complex<float>);
DECLARE_GPU_SPEC(std::complex<double>);
#endif
#undef DECLARE_GPU_SPEC
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -67,11 +67,11 @@ class CSRNNZOp : public OpKernel {
REGISTER(CPU)
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU)
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER

View File

@ -19,7 +19,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif
@ -84,7 +84,7 @@ class CSRSoftmaxOp : public OpKernel {
}
};
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV, T) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \
.Device(DEVICE_##DEV) \
@ -110,7 +110,7 @@ DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T>
class CSRSoftmaxGradOp : public OpKernel {
@ -193,7 +193,7 @@ class CSRSoftmaxGradOp : public OpKernel {
}
};
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV, T) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
.Device(DEVICE_##DEV) \
@ -220,6 +220,6 @@ DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -500,22 +500,24 @@ REGISTER_CPU(complex128)
.TypeConstraint<T>("type"), \
CSRSparseMatMulGPUOp<DEV##Device, T>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) REGISTER(GPU, T)
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif // GOOGLE_CUDA
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {
template <typename T>
struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
@ -529,11 +531,19 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
adjoint_a_(adjoint_a),
transpose_b_(transpose_b) {
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
#if GOOGLE_CUDA
transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
: CUSPARSE_OPERATION_NON_TRANSPOSE;
transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
transA_ = transpose_a ? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
transB_ = transpose_b ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
Status Initialize() {
@ -646,6 +656,6 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

View File

@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif
@ -116,12 +116,14 @@ REGISTER(CPU, double)
REGISTER(CPU, complex64)
REGISTER(CPU, complex128)
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU, float)
REGISTER(GPU, double)
#if GOOGLE_CUDA
REGISTER(GPU, complex64)
REGISTER(GPU, complex128)
#endif
#undef REGISTER
@ -139,12 +141,14 @@ namespace functor {
DECLARE_GPU_SPEC(int32);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if GOOGLE_CUDA
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
#undef DECLARE_GPU_SPEC
} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -30,13 +30,18 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
#endif
namespace tensorflow {
@ -104,7 +109,7 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
}
};
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T>
class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@ -322,12 +327,14 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \

View File

@ -19,7 +19,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif
@ -132,9 +132,12 @@ REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128)
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double)
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128)
#endif // GOOGLE_CUDA
@ -250,7 +253,7 @@ struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
}
};
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
@ -259,7 +262,11 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
GpuSparse cuda_sparse(ctx);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
#if GOOGLE_CUDA
const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
#elif TENSORFLOW_USE_ROCM
const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int rank = x.dense_shape_host.size();
const int m = x.row_ptr.size() - 1;
const int n = x.dense_shape_host(rank - 1);
@ -279,7 +286,7 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
return Status::OK();
}
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace functor
} // namespace tensorflow

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif
@ -74,7 +74,7 @@ Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx,
} // namespace
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \
.Device(DEVICE_##DEV) \
@ -88,6 +88,6 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CSRSparseMatrixZerosLikeHelper<GPUDevice>);
#undef REGISTER
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif