Adding ROCm support for the CSR Sparse Matrix Ops
parent 5ad7620d6f
commit 2e1cdaa4b6
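This commit applies TensorFlow's standard CUDA/ROCm portability recipe across the CSR sparse kernels: every #if GOOGLE_CUDA guard widens to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM, vendor-specific calls go through gpu-prefixed aliases, and features that only cuSPARSE provides (notably the complex64/complex128 kernels) stay behind a narrower CUDA-only guard. A minimal sketch of the resulting guard structure; DoGpuWork is a placeholder for illustration, not a function from this commit:

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU  // The GPU path now builds for both vendors.

void DoGpuWork() {
#if GOOGLE_CUDA
  // cuSPARSE-only branch, e.g. the complex-type kernels below.
#elif TENSORFLOW_USE_ROCM
  // hipSPARSE branch.
#endif
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM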
@@ -2,6 +2,7 @@
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_cuda_or_rocm",
     "tf_cc_test",
     "tf_kernel_library",
 )
@@ -77,7 +78,7 @@ tf_kernel_library(
         "//tensorflow/core/kernels:scatter_nd_op",
         "//tensorflow/core/kernels:slice_op",
         "//tensorflow/core/kernels:transpose_functor",
-    ] + if_cuda([
+    ] + if_cuda_or_rocm([
         "//tensorflow/core/kernels:cuda_solvers",
         "//tensorflow/core/kernels:cuda_sparse",
     ]),
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -31,7 +31,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -233,8 +233,10 @@ class CSRAddOp : public OpKernel {
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
 #undef REGISTER_GPU
 
@@ -246,7 +248,7 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
 
 #undef REGISTER
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace functor {
 template <typename T>
 struct CSRSparseMatrixAdd<GPUDevice, T>
@@ -337,6 +339,6 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -31,7 +31,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -92,12 +92,12 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
     CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix,
     (CSRSparseMatrixUnaryHelper<CPUDevice, CSRSparseMatrixConjFunctor>));
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
     CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
     (CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>));
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -33,7 +33,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -220,19 +220,21 @@ REGISTER_CPU(double)
 REGISTER_CPU(complex64)
 REGISTER_CPU(complex128)
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER_CPU
 #undef REGISTER_GPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace functor {
 template <>
@@ -256,6 +258,6 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -31,7 +31,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -205,18 +205,20 @@ class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel {
           .HostMemory("dense_shape"),                       \
       CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER_GPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace functor {
 template <>
@@ -240,7 +242,7 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_CPU(T)                                     \
   REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -32,13 +32,18 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/kernels.h"
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #endif
 
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/stream_executor/rocm/rocm_activation.h"
+using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
+#endif
+
 namespace tensorflow {
@@ -138,7 +143,7 @@ REGISTER_CPU(complex128)
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@@ -356,8 +361,10 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
 
 REGISTER_GPU(GPU, float)
 REGISTER_GPU(GPU, double)
+#if GOOGLE_CUDA
 REGISTER_GPU(GPU, complex64)
 REGISTER_GPU(GPU, complex128)
+#endif
 
 namespace functor {
 
@@ -391,7 +398,7 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER_GPU
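Both GPU conversion ops in this commit (DenseToCSRSparseMatrixGPUOp above and SparseTensorToCSRSparseMatrixGPUOp later) wrap their device work in a ScopedActivateExecutorContext; only the header and the namespace of that RAII type differ between backends, which is what the new #if/#elif include block selects. A hedged sketch of how the shared spelling is then used inside the async kernel; the surrounding callback plumbing is elided and `stream` is assumed to be the op's stream-executor stream:

{
  // stream->parent() is the StreamExecutor; the guard keeps its GPU
  // context active for every launch in this scope, on CUDA or ROCm alike.
  ScopedActivateExecutorContext scoped_activation{stream->parent()};
  // ... enqueue GPU work on the now-active context ...
}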
@@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
 #include "third_party/cub/device/device_histogram.cuh"
 #include "third_party/cub/iterator/counting_input_iterator.cuh"
 #include "third_party/cub/iterator/transform_input_iterator.cuh"
 #include "third_party/gpus/cuda/include/cusparse.h"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
@@ -32,6 +36,12 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
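With the gpuprim alias in place, the call sites below can name one vendor-neutral namespace and resolve to CUB under CUDA or hipCUB under ROCm, since the two libraries deliberately mirror each other's APIs. A minimal sketch of a device primitive written against the alias; SumOnDevice is an illustrative helper, not code from this commit:

// Two-phase CUB/hipCUB convention: call once with d_temp == nullptr to get
// the required scratch size in temp_bytes, then again with a real buffer to
// run the reduction. The deduced return type is cudaError_t under CUDA and
// hipError_t under ROCm, which is exactly why the alias is useful.
template <typename T>
auto SumOnDevice(void* d_temp, size_t& temp_bytes, const T* d_in, T* d_out,
                 int n) {
  return gpuprim::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
}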
@@ -65,9 +75,9 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
   DCHECK_EQ(indices.dimension(1), 3);  // batch, row, col
 
   const int rank = indices.dimension(1);
-  cub::CountingInputIterator<int> row_counter(0);
-  cub::TransformInputIterator<int, StridedDataReader,
-                              cub::CountingInputIterator<int>>
+  gpuprim::CountingInputIterator<int> row_counter(0);
+  gpuprim::TransformInputIterator<int, StridedDataReader,
+                                  gpuprim::CountingInputIterator<int>>
       indices_first_column(row_counter,
                            StridedDataReader(indices.data(), rank));
 
@@ -76,7 +86,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
   DCHECK_NE(indices.data(), nullptr);
   DCHECK_NE(nnz_per_batch.data(), nullptr);
 
-  auto first_success = cub::DeviceHistogram::HistogramEven(
+  auto first_success = gpuprim::DeviceHistogram::HistogramEven(
       /*d_temp_storage*/ nullptr,
       /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
@@ -87,12 +97,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
       /*num_samples*/ total_nnz,
       /*stream*/ cu_stream);
 
-  if (first_success != cudaSuccess) {
+  if (first_success != gpuSuccess) {
     return errors::Internal(
         "SparseTensorToCSRSparseMatrix: Could not launch "
-        "cub::DeviceHistogram::HistogramEven "
+        "gpuprim::DeviceHistogram::HistogramEven "
         "to calculate temp_storage_bytes, status: ",
-        cudaGetErrorString(first_success));
+        GpuGetErrorString(first_success));
   }
 
   Tensor temp_storage;
@@ -100,7 +110,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
           DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
           &temp_storage));
   DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
-  auto second_success = cub::DeviceHistogram::HistogramEven(
+  auto second_success = gpuprim::DeviceHistogram::HistogramEven(
       /*d_temp_storage*/ temp_storage.flat<int8>().data(),
       /*temp_storage_bytes&*/ temp_storage_bytes,
       /*d_samples*/ indices_first_column,
@@ -111,12 +121,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
       /*num_samples*/ total_nnz,
       /*stream*/ cu_stream);
 
-  if (second_success != cudaSuccess) {
+  if (second_success != gpuSuccess) {
     return errors::Internal(
         "SparseTensorToCSRSparseMatrix: Could not launch "
-        "cub::DeviceHistogram::HistogramEven "
+        "gpuprim::DeviceHistogram::HistogramEven "
        "to count nnz entries per batch. temp_storage_bytes: ",
-        temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
+        temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
   }
 
   return Status::OK();
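The histogram error checks above compile on both platforms because gpuSuccess and GpuGetErrorString abstract over the cudaError_t/hipError_t surfaces; the real definitions live in TensorFlow's GPU helper headers (gpu_kernel_helper.h is included above). Schematically, and only as an illustration of the idea rather than TF's exact code:

#if GOOGLE_CUDA
typedef cudaError_t gpuError_t;
#define gpuSuccess cudaSuccess
inline const char* GpuGetErrorString(gpuError_t error) {
  return cudaGetErrorString(error);
}
#elif TENSORFLOW_USE_ROCM
typedef hipError_t gpuError_t;
#define gpuSuccess hipSuccess
inline const char* GpuGetErrorString(gpuError_t error) {
  return hipGetErrorString(error);
}
#endif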
@@ -128,11 +138,11 @@ template <>
 Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
     OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
     TTypes<int>::UnalignedVec coo_row_ind) {
-  GpuSparse cuda_sparse(c);
+  GpuSparse gpu_sparse(c);
   const int nnz = coo_row_ind.size();
-  TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
+  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
   const int m = csr_row_ptr.size() - 1;  // rows
-  return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
+  return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
 }
 
 template <int stride>
@@ -140,7 +150,7 @@ __global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
                                               int* coo_rows_out,
                                               int* coo_cols_out, int size) {
   const int offset = (stride == 3) ? 1 : 0;
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
     coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
   }
@@ -157,20 +167,22 @@ void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
   const int size = coo_row_ind.dimension(0);
   GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
   if (stride == 2) {
-    SparseTensorToCOOMatrixKernel<2>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
+    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), indices.data(), coo_row_ind.data(),
+                                coo_col_ind.data(), size));
   } else {
-    SparseTensorToCOOMatrixKernel<3>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
+    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), indices.data(), coo_row_ind.data(),
+                                coo_col_ind.data(), size));
   }
 }
 
 __global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
                                                 const int* coo_cols,
                                                 int64* indices_out, int size) {
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
     indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
   }
@@ -203,7 +215,7 @@ __global__ void COOMatrixToSparseTensorKernel3D(
   }
   __syncthreads();
 
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     // TODO(ebrevdo): Consider special casing batch_size <= 3,
     // alternatively doing linear instead of binary search. Requires
     // some benchmarks.
@@ -231,9 +243,10 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
   DCHECK_EQ(size, indices.dimension(0));
   if (ndims == 2) {
     GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
-    COOMatrixToSparseTensorKernel2D<<<config.block_count,
-                                      config.thread_per_block, 0, d.stream()>>>(
-        coo_row_ind.data(), coo_col_ind.data(), indices.data(), size);
+    TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), coo_row_ind.data(),
+                                coo_col_ind.data(), indices.data(), size));
     return Status::OK();
   } else {
     const int batch_size = host_dense_shape(0);
@@ -246,11 +259,11 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
     GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
     // shared memory stores the batch pointers.
     const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
-    COOMatrixToSparseTensorKernel3D<<<config.block_count,
-                                      config.thread_per_block,
-                                      shared_memory_size, d.stream()>>>(
-        coo_row_ind.data(), coo_col_ind.data(), indices.data(),
-        batch_ptr_copy.data(), batch_size, size);
+    TF_CHECK_OK(
+        GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
+                        config.thread_per_block, shared_memory_size, d.stream(),
+                        coo_row_ind.data(), coo_col_ind.data(), indices.data(),
+                        batch_ptr_copy.data(), batch_size, size));
     return Status::OK();
   }
 }
@@ -274,7 +287,7 @@ __global__ void CSRSparseMatrixBatchMulVecKernel3D(
   }
   __syncthreads();
 
-  CUDA_1D_KERNEL_LOOP(i, total_nnz) {
+  GPU_1D_KERNEL_LOOP(i, total_nnz) {
     const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
     c_values[i] = ldg(a_values + i) * local_batch_values[b];
   }
@@ -316,10 +329,10 @@ Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
   const size_t shared_memory_size =
       (sizeof(int) * (batch_size + 1)  // local batch_pointers.
        + sizeof(T) * batch_size);      // local copy of b.
-  CSRSparseMatrixBatchMulVecKernel3D<T>
-      <<<config.block_count, config.thread_per_block, shared_memory_size,
-         d.stream()>>>(a_values.data(), b.data(), c_values.data(),
-                       batch_ptr_copy.data(), batch_size, total_nnz);
+  TF_CHECK_OK(GpuLaunchKernel(
+      CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
+      config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
+      b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));
 
   return Status::OK();
 }
@@ -374,7 +387,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
   // algorithm to distribute the work in case the row sizes are
   // uneven:
   // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
-  CUDA_1D_KERNEL_LOOP(row, rows) {
+  GPU_1D_KERNEL_LOOP(row, rows) {
     CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
                         softmax);
   }
@@ -382,7 +395,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
 
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
     GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
-#ifdef __CUDA_ARCH__
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
   for (int i = threadIdx.x; i < length; i += blockDim.x) {
     local_ptr[i] = cuda_ptr[i];
@@ -404,7 +417,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel3D(
   CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
                                 batch_size + 1);
 
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     const int batch = i / rows;
     const int row = i % rows;
     const int batch_offset = local_batch_ptr[batch];
@@ -431,10 +444,10 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
     const int rows = host_dense_shape(0);
     DCHECK_EQ(rows, row_ptr.size() - 1);
     GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
-    CSRSparseMatrixSoftmaxKernel2D<T>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            rows /*size*/, row_ptr.data(), logits_values.data(),
-            softmax_values.data());
+    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), rows /*size*/, row_ptr.data(),
+                                logits_values.data(), softmax_values.data()));
   } else {
     const int batch_size = host_dense_shape(0);
     const int rows = host_dense_shape(1);
@@ -452,10 +465,11 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
     GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
     // shared memory stores the batch pointers.
     const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
-    CSRSparseMatrixSoftmaxKernel3D<T>
-        <<<config.block_count, config.thread_per_block, shared_memory_size,
-           d.stream()>>>(size, rows, batch_ptr_copy.data(), row_ptr.data(),
-                         logits_values.data(), softmax_values.data());
+    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
+                                config.block_count, config.thread_per_block,
+                                shared_memory_size, d.stream(), size, rows,
+                                batch_ptr_copy.data(), row_ptr.data(),
+                                logits_values.data(), softmax_values.data()));
   }
 
   return Status::OK();
@@ -549,7 +563,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel2D(
   // algorithm to distribute the work in case the row sizes are
   // uneven:
   // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
-  CUDA_1D_KERNEL_LOOP(row, rows) {
+  GPU_1D_KERNEL_LOOP(row, rows) {
     CalculateRowSoftmaxGrad(
         ldg(softmax_row_ptr + row) /*softmax_begin*/,
         ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
@@ -579,7 +593,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel3D(
 #define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
 #define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];
 
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     const int batch = i / rows;
     const int row = i % rows;
     const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
@@ -625,12 +639,12 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
     DCHECK_EQ(rows + 1, softmax_row_ptr.size());
     DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
     GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
-    CSRSparseMatrixSoftmaxGradKernel2D<T>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(),
-            softmax_values.data(), grad_softmax_row_ptr.data(),
-            grad_softmax_col_ind.data(), grad_softmax_values.data(),
-            gradient_values.data());
+    TF_CHECK_OK(GpuLaunchKernel(
+        CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
+        config.thread_per_block, 0, d.stream(), rows /*size*/,
+        softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
+        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
+        grad_softmax_values.data(), gradient_values.data()));
   } else {
     const int batch_size = host_dense_shape(0);
     const int rows = host_dense_shape(1);
@@ -656,13 +670,13 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
     // shared memory stores two copies of batch pointers: one for the
     // softmax CSR matrix, one for the grad_softmax CSR matrix.
     const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
-    CSRSparseMatrixSoftmaxGradKernel3D<T>
-        <<<config.block_count, config.thread_per_block, shared_memory_size,
-           d.stream()>>>(size, rows, softmax_and_grad_batch_ptr_copy.data(),
-                         softmax_row_ptr.data(), softmax_col_ind.data(),
-                         softmax_values.data(), grad_softmax_row_ptr.data(),
-                         grad_softmax_col_ind.data(),
-                         grad_softmax_values.data(), gradient_values.data());
+    TF_CHECK_OK(GpuLaunchKernel(
+        CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
+        config.thread_per_block, shared_memory_size, d.stream(), size, rows,
+        softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
+        softmax_col_ind.data(), softmax_values.data(),
+        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
+        grad_softmax_values.data(), gradient_values.data()));
   }
 
   return Status::OK();
@@ -687,4 +701,4 @@ DEFINE_SOFTMAX_GRAD_GPU(double);
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
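The other recurring rewrite in this file replaces the triple-chevron kernel<<<grid, block, shmem, stream>>>(args) launches, which only the CUDA front end accepts, with GpuLaunchKernel(kernel, grid, block, shmem, stream, args...), which expands correctly under both nvcc and hipcc and returns a Status that the call sites wrap in TF_CHECK_OK. A condensed sketch of the before/after shape; MyDoubleKernel is a placeholder, not a kernel from this commit:

__global__ void MyDoubleKernel(const int* in, int* out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = in[i] * 2; }
}

void Launch(const Eigen::GpuDevice& d, const int* in, int* out, int size) {
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  // Before: MyDoubleKernel<<<config.block_count, config.thread_per_block, 0,
  //                          d.stream()>>>(in, out, size);
  TF_CHECK_OK(GpuLaunchKernel(MyDoubleKernel, config.block_count,
                              config.thread_per_block, 0 /*shared_mem_bytes*/,
                              d.stream(), in, out, size));
}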
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -36,7 +36,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/threadpool.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -694,7 +694,7 @@ REGISTER_CPU(complex128)
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T)      \
   REGISTER_KERNEL_BUILDER(   \
@@ -703,14 +703,16 @@ REGISTER_CPU(complex128)
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace functor {
 
@@ -741,11 +743,16 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
 
     // transA must be non-transpose if transB is transpose (cusparse
     // limitation).
+#if GOOGLE_CUDA
     const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
+#elif TENSORFLOW_USE_ROCM
+    const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
+#endif
 
     // transB: b is row-major, and cusparse requires col-major b (or
     // equivalently transB == transpose).  this version is actually more
     // efficient.
+#if GOOGLE_CUDA
     const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
 
     gpusparseMatDescr_t descrA;
@@ -754,6 +761,16 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
         cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
     TF_RETURN_IF_GPUSPARSE_ERROR(
         cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
+#elif TENSORFLOW_USE_ROCM
+    const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
+
+    gpusparseMatDescr_t descrA;
+    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+#endif
 
     // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
     const int k = b.dimension(0);
@@ -816,11 +833,19 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
     const T beta = 0;
 
     gpusparseMatDescr_t descrA;
+#if GOOGLE_CUDA
     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
     TF_RETURN_IF_GPUSPARSE_ERROR(
         cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
     TF_RETURN_IF_GPUSPARSE_ERROR(
         cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
+#elif TENSORFLOW_USE_ROCM
+    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
     const int m = a.dense_shape_host(0);
     const int n = a.dense_shape_host(1);
@@ -841,6 +866,6 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
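The mat_mul changes above duplicate the cuSPARSE descriptor setup for hipSPARSE in two places; the entry points and enum constants differ only in prefix, so the duplication could also be factored into a helper. A hedged sketch of such a helper (MakeGeneralZeroBasedDescr is hypothetical, not part of this commit; it assumes the gpusparseMatDescr_t alias and the TF_RETURN_IF_GPUSPARSE_ERROR macro from cuda_sparse.h):

Status MakeGeneralZeroBasedDescr(gpusparseMatDescr_t* descrA) {
#if GOOGLE_CUDA
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatType(*descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      cusparseSetMatIndexBase(*descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(descrA));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseSetMatType(*descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseSetMatIndexBase(*descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif
  return Status::OK();
}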
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -28,7 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/kernels.h"
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
 
@@ -101,22 +101,24 @@ class CSRMulOp : public OpKernel {
       Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
       CSRMulOp<DEV##Device, T>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T) REGISTER(GPU, T)
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace functor {
 
@@ -159,13 +161,15 @@ class CSRSparseMatrixMulScalar<GPUDevice, T> {
 
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
+#if GOOGLE_CUDA
 DECLARE_GPU_SPEC(std::complex<float>);
 DECLARE_GPU_SPEC(std::complex<double>);
+#endif
 
 #undef DECLARE_GPU_SPEC
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -28,7 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/kernels.h"
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -67,11 +67,11 @@ class CSRNNZOp : public OpKernel {
 
 REGISTER(CPU)
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 REGISTER(GPU)
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER
@@ -19,7 +19,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #define EIGEN_USE_GPU
 #endif
@@ -84,7 +84,7 @@ class CSRSoftmaxOp : public OpKernel {
   }
 };
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER(DEV, T)                                  \
   REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax")     \
                               .Device(DEVICE_##DEV)       \
@@ -110,7 +110,7 @@ DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class CSRSoftmaxGradOp : public OpKernel {
@@ -193,7 +193,7 @@ class CSRSoftmaxGradOp : public OpKernel {
   }
 };
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER(DEV, T)                                  \
   REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
                               .Device(DEVICE_##DEV)       \
@@ -220,6 +220,6 @@ DECLARE_GPU_SPEC(double);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -35,7 +35,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -500,22 +500,24 @@ REGISTER_CPU(complex128)
           .TypeConstraint<T>("type"),                     \
       CSRSparseMatMulGPUOp<DEV##Device, T>);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_GPU(T) REGISTER(GPU, T)
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif  // GOOGLE_CUDA
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #undef REGISTER
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace functor {
 template <typename T>
 struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
@@ -529,11 +531,19 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         adjoint_a_(adjoint_a),
         transpose_b_(transpose_b) {
     // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
+#if GOOGLE_CUDA
    transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE
                                       : CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
                          : CUSPARSE_OPERATION_NON_TRANSPOSE;
    transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
                          : CUSPARSE_OPERATION_NON_TRANSPOSE;
+#elif TENSORFLOW_USE_ROCM
+    transA_ = transpose_a ? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE
+                                       : HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
+                          : HIPSPARSE_OPERATION_NON_TRANSPOSE;
+    transB_ = transpose_b ? HIPSPARSE_OPERATION_TRANSPOSE
+                          : HIPSPARSE_OPERATION_NON_TRANSPOSE;
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   }
 
   Status Initialize() {
@@ -646,6 +656,6 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
 
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -18,7 +18,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -29,7 +29,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/kernels.h"
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #endif
@@ -116,12 +116,14 @@ REGISTER(CPU, double)
 REGISTER(CPU, complex64)
 REGISTER(CPU, complex128)
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 REGISTER(GPU, float)
 REGISTER(GPU, double)
+#if GOOGLE_CUDA
 REGISTER(GPU, complex64)
 REGISTER(GPU, complex128)
+#endif
 
 #undef REGISTER
 
@@ -139,12 +141,14 @@ namespace functor {
 DECLARE_GPU_SPEC(int32);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
+#if GOOGLE_CUDA
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
+#endif
 
 #undef DECLARE_GPU_SPEC
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -30,13 +30,18 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse/kernels.h"
 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/kernels/cuda_sparse.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #endif
 
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/stream_executor/rocm/rocm_activation.h"
+using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
+#endif
+
 namespace tensorflow {
@@ -104,7 +109,7 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
   }
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T>
 class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@@ -322,12 +327,14 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
 
 REGISTER_GPU(float)
 REGISTER_GPU(double)
+#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
+#endif
 
 #undef REGISTER_GPU
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_CPU(T)                                   \
   REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \
@@ -19,7 +19,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cuda_sparse.h"
 #define EIGEN_USE_GPU
 #endif
@@ -132,9 +132,12 @@ REGISTER_TRANSPOSE(CPU, double)
 REGISTER_TRANSPOSE(CPU, complex64)
 REGISTER_TRANSPOSE(CPU, complex128)
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER_TRANSPOSE(GPU, float)
 REGISTER_TRANSPOSE(GPU, double)
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#if GOOGLE_CUDA
 REGISTER_TRANSPOSE(GPU, complex64)
 REGISTER_TRANSPOSE(GPU, complex128)
 #endif  // GOOGLE_CUDA
@@ -250,7 +253,7 @@ struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
   }
 };
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
@@ -259,7 +262,11 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
     TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
     GpuSparse cuda_sparse(ctx);
     TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
+#if GOOGLE_CUDA
     const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
+#elif TENSORFLOW_USE_ROCM
+    const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC;
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     const int rank = x.dense_shape_host.size();
     const int m = x.row_ptr.size() - 1;
     const int n = x.dense_shape_host(rank - 1);
@@ -279,7 +286,7 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
     return Status::OK();
   }
 };
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }  // namespace functor
 
 }  // namespace tensorflow
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif
 
@@ -74,7 +74,7 @@ Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx,
 
 }  // namespace
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER(DEV)                                     \
   REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros")       \
                               .Device(DEVICE_##DEV)       \
@@ -88,6 +88,6 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
     CSRSparseMatrixZerosLikeHelper<GPUDevice>);
 
 #undef REGISTER
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
 
@@ -18,7 +18,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #endif