From 815fa1866caa88727d2d5288c028f0e610c4bda3 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Wed, 15 Jan 2020 20:49:23 -0800 Subject: [PATCH 1/2] Complex sparse ops --- .../sparse/csr_sparse_matrix_to_dense_op.cc | 2 -- .../sparse/dense_to_csr_sparse_matrix_op.cc | 2 -- tensorflow/core/kernels/sparse/mat_mul_op.cc | 8 ++++++-- tensorflow/core/kernels/sparse/mul_op.cc | 2 -- .../kernels/sparse/sparse_matrix_components_op.cc | 4 ---- .../sparse_tensor_to_csr_sparse_matrix_op.cc | 2 -- .../csr_sparse_matrix_dense_mat_mul_grad_test.py | 5 +---- .../linalg/sparse/csr_sparse_matrix_ops_test.py | 15 +-------------- .../linalg/sparse/csr_sparse_matrix_test.py | 15 +++------------ 9 files changed, 11 insertions(+), 44 deletions(-) diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc index 9e5a11c4aeb..364c2c07bd8 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc @@ -224,10 +224,8 @@ REGISTER_CPU(complex128) REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc index b42d315789b..f021a8f4df3 100644 --- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc @@ -361,10 +361,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { REGISTER_GPU(GPU, float) REGISTER_GPU(GPU, double) -#if GOOGLE_CUDA REGISTER_GPU(GPU, complex64) REGISTER_GPU(GPU, complex128) -#endif namespace functor { diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index a57d97b7a73..58a7d3e5a7e 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -538,8 +538,12 @@ class CSRMatMulGPUOp : public CSRMatMulOp { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t)); const GPUDevice& d = ctx->eigen_device(); - - if (b_outer_dim == 1) { + bool shortcut_ok = (b_outer_dim == 1); +#if TENSORFLOW_USE_ROCM + // ROCm hipsparse does not implement csrmv with transposed input a + shortcut_ok = shortcut_ok && !this->transpose_a_; +#endif + if (shortcut_ok) { // Call matrix-vector multiply if b is a vector. TTypes::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim, 2); diff --git a/tensorflow/core/kernels/sparse/mul_op.cc b/tensorflow/core/kernels/sparse/mul_op.cc index f6cf369626c..33c3756ce58 100644 --- a/tensorflow/core/kernels/sparse/mul_op.cc +++ b/tensorflow/core/kernels/sparse/mul_op.cc @@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel { REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #undef REGISTER_GPU diff --git a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc index 9cbe88bde6c..59540f63846 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc @@ -120,10 +120,8 @@ REGISTER(CPU, complex128) REGISTER(GPU, float) REGISTER(GPU, double) -#if GOOGLE_CUDA REGISTER(GPU, complex64) REGISTER(GPU, complex128) -#endif #undef REGISTER @@ -141,10 +139,8 @@ namespace functor { DECLARE_GPU_SPEC(int32); DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); -#if GOOGLE_CUDA DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex128); -#endif #undef DECLARE_GPU_SPEC } // namespace functor diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc index 47efd24f83a..efe6e3ed14f 100644 --- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc @@ -327,10 +327,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix; REGISTER_GPU(float) REGISTER_GPU(double) -#if GOOGLE_CUDA REGISTER_GPU(complex64) REGISTER_GPU(complex128) -#endif #undef REGISTER_GPU diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py index 5cd206ccbc1..4841c18a78c 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_dense_mat_mul_grad_test.py @@ -106,10 +106,7 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase): # These tests are refactored from sparse_csr_matrix_grad_test to keep its size # "medium". -dtypes_to_test = [np.float32] -if not test.is_built_with_rocm: - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] +dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: for (t_a, t_b, adj_a, adj_b, t_out, conj_out) in itertools.product(*(([False, True],) * 6)): diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index 51757802968..ac82f190db0 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -517,9 +517,6 @@ class CSRSparseMatrixOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testSparseMatrixMatMulConjugateOutput(self): - if test.is_built_with_rocm(): - self.skipTest("complex type not supported on ROCm") - for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]: a_indices = np.array([[0, 0], [2, 3]]) a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64) @@ -542,17 +539,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testLargeBatchSparseMatrixMatMul(self): - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex types is not supported on the ROCm platform - dtypes_to_test += [np.complex64] - - if test.is_built_with_rocm(): - # TODO(rocm): fix this - # This test is currently failing on the ROCm platform - # Ren-enable it once the fix is available - self.skipTest("hipSPARSE all failure on the ROCm platform") - + dtypes_to_test = [np.float32, np.complex64] sparsify = lambda m: m * (m > 0) for dtype in dtypes_to_test: for (transpose_a, transpose_b) in ((False, False), (False, True), diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py index 66077f5b2d2..35c706cb36a 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py @@ -154,10 +154,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a))).astype(dtype) @@ -198,10 +195,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = sparsify((np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a))).astype(dtype) @@ -239,10 +233,7 @@ class SparseMatrixMatmulTest(test.TestCase): sparsify = lambda m: m * (m > 0) dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] - dtypes_to_test = [np.float32] - if not test.is_built_with_rocm(): - # complex type is not supported on the ROCm platform - dtypes_to_test += [np.complex64] + dtypes_to_test = [np.float32, np.complex64] for dtype in dtypes_to_test: a_mats = (np.random.randn(*dense_shape_a) + 1.j * np.random.randn(*dense_shape_a)).astype(dtype) From 15e47e640272ec89ec76098354707969dbcf3c67 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Sat, 1 Feb 2020 01:59:15 -0800 Subject: [PATCH 2/2] Rename a variable --- tensorflow/core/kernels/sparse/mat_mul_op.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index 58a7d3e5a7e..40fc4f27568 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -538,12 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t)); const GPUDevice& d = ctx->eigen_device(); - bool shortcut_ok = (b_outer_dim == 1); + bool use_matrix_vector_multiply = (b_outer_dim == 1); #if TENSORFLOW_USE_ROCM // ROCm hipsparse does not implement csrmv with transposed input a - shortcut_ok = shortcut_ok && !this->transpose_a_; + use_matrix_vector_multiply = use_matrix_vector_multiply && + !this->transpose_a_; #endif - if (shortcut_ok) { + if (use_matrix_vector_multiply) { // Call matrix-vector multiply if b is a vector. TTypes::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim, 2);