Merge pull request #35966 from ROCmSoftwarePlatform:google-upstream-sparse-complex

PiperOrigin-RevId: 303368854
Change-Id: If20570eeb12f957557bf0bdd005db6b7bf974f65
Author: TensorFlower Gardener
Date:   2020-03-27 11:24:26 -07:00
Commit: bffbe8736e

9 changed files with 12 additions and 44 deletions


@@ -224,10 +224,8 @@ REGISTER_CPU(complex128)
 REGISTER_GPU(float)
 REGISTER_GPU(double)
-#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
-#endif
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
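
The same two-line deletion recurs in every kernel file below: the complex64/complex128 GPU registrations used to sit inside an extra "#if GOOGLE_CUDA", and dropping that inner guard leaves them covered by the enclosing "GOOGLE_CUDA || TENSORFLOW_USE_ROCM" block, so they now also build for ROCm. A rough sketch of the resulting pattern, with the op name "SparseMatrixAdd" and the kernel class as illustrative placeholders rather than the actual contents of this file:

    // Illustrative only: op name and kernel class are placeholders.
    #define REGISTER_GPU(T)                                      \
      REGISTER_KERNEL_BUILDER(Name("SparseMatrixAdd")            \
                                  .Device(DEVICE_GPU)            \
                                  .TypeConstraint<T>("T"),       \
                              CSRSparseMatrixAddOp<GPUDevice, T>);

    #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    REGISTER_GPU(float)
    REGISTER_GPU(double)
    REGISTER_GPU(complex64)   // no longer wrapped in an inner "#if GOOGLE_CUDA"
    REGISTER_GPU(complex128)
    #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    #undef REGISTER_GPU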


@@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
 REGISTER_GPU(GPU, float)
 REGISTER_GPU(GPU, double)
-#if GOOGLE_CUDA
 REGISTER_GPU(GPU, complex64)
 REGISTER_GPU(GPU, complex128)
-#endif
 namespace functor {


@@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
     const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-    if (b_outer_dim == 1) {
+    bool use_matrix_vector_multiply = (b_outer_dim == 1);
+#if TENSORFLOW_USE_ROCM
+    // ROCm hipsparse does not implement csrmv with transposed input a
+    use_matrix_vector_multiply =
+        use_matrix_vector_multiply && !this->transpose_a_;
+#endif
+    if (use_matrix_vector_multiply) {
       // Call matrix-vector multiply if b is a vector.
       TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
                                                  2);
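
Distilling the new control flow above into a standalone sketch (identifiers such as b_outer_dim and transpose_a_ come from the hunk; the cuSPARSE/hipSPARSE calls are elided, and the comments reflect a reading of the change, not the file's own documentation): the vector-path predicate is computed once and then masked on ROCm, because hipsparse's csrmv has no transposed-A mode, so any product that needs transpose(a) falls through to the general matrix-matrix path.

    // Sketch of the dispatch only; the actual SpMV/SpMM calls are elided.
    bool use_matrix_vector_multiply = (b_outer_dim == 1);
    #if TENSORFLOW_USE_ROCM
    // hipsparse csrmv cannot transpose the sparse operand, so force the
    // matrix-matrix path whenever a must be transposed.
    use_matrix_vector_multiply =
        use_matrix_vector_multiply && !this->transpose_a_;
    #endif
    if (use_matrix_vector_multiply) {
      // csrmv path: b is a single column, treated as a dense vector.
    } else {
      // csrmm path: supports a transposed sparse operand on CUDA and ROCm.
    }

Folding the platform condition into one boolean keeps a single if/else rather than duplicating the branch bodies under #if TENSORFLOW_USE_ROCM.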


@@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel {
 REGISTER_GPU(float)
 REGISTER_GPU(double)
-#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
-#endif
 #undef REGISTER_GPU


@@ -120,10 +120,8 @@ REGISTER(CPU, complex128)
 REGISTER(GPU, float)
 REGISTER(GPU, double)
-#if GOOGLE_CUDA
 REGISTER(GPU, complex64)
 REGISTER(GPU, complex128)
-#endif
 #undef REGISTER
@@ -141,10 +139,8 @@ namespace functor {
 DECLARE_GPU_SPEC(int32);
 DECLARE_GPU_SPEC(float);
 DECLARE_GPU_SPEC(double);
-#if GOOGLE_CUDA
 DECLARE_GPU_SPEC(complex64);
 DECLARE_GPU_SPEC(complex128);
-#endif
 #undef DECLARE_GPU_SPEC
 }  // namespace functor


@@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
 REGISTER_GPU(float)
 REGISTER_GPU(double)
-#if GOOGLE_CUDA
 REGISTER_GPU(complex64)
 REGISTER_GPU(complex128)
-#endif
 #undef REGISTER_GPU


@@ -106,10 +106,7 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
   # These tests are refactored from sparse_csr_matrix_grad_test to keep its size
   # "medium".
-  dtypes_to_test = [np.float32]
-  if not test.is_built_with_rocm:
-    # complex type is not supported on the ROCm platform
-    dtypes_to_test += [np.complex64]
+  dtypes_to_test = [np.float32, np.complex64]
   for dtype in dtypes_to_test:
     for (t_a, t_b, adj_a, adj_b, t_out,
          conj_out) in itertools.product(*(([False, True],) * 6)):


@@ -517,9 +517,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testSparseMatrixMatMulConjugateOutput(self):
-    if test.is_built_with_rocm():
-      self.skipTest("complex type not supported on ROCm")
     for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
       a_indices = np.array([[0, 0], [2, 3]])
       a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
@@ -542,17 +539,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testLargeBatchSparseMatrixMatMul(self):
-    dtypes_to_test = [np.float32]
-    if not test.is_built_with_rocm():
-      # complex types is not supported on the ROCm platform
-      dtypes_to_test += [np.complex64]
-    if test.is_built_with_rocm():
-      # TODO(rocm): fix this
-      # This test is currently failing on the ROCm platform
-      # Ren-enable it once the fix is available
-      self.skipTest("hipSPARSE all failure on the ROCm platform")
+    dtypes_to_test = [np.float32, np.complex64]
     sparsify = lambda m: m * (m > 0)
     for dtype in dtypes_to_test:
       for (transpose_a, transpose_b) in ((False, False), (False, True),


@@ -154,10 +154,7 @@ class SparseMatrixMatmulTest(test.TestCase):
     sparsify = lambda m: m * (m > 0)
     dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
     dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
-    dtypes_to_test = [np.float32]
-    if not test.is_built_with_rocm():
-      # complex type is not supported on the ROCm platform
-      dtypes_to_test += [np.complex64]
+    dtypes_to_test = [np.float32, np.complex64]
     for dtype in dtypes_to_test:
       a_mats = sparsify((np.random.randn(*dense_shape_a) +
                          1.j * np.random.randn(*dense_shape_a))).astype(dtype)
@@ -198,10 +195,7 @@ class SparseMatrixMatmulTest(test.TestCase):
     sparsify = lambda m: m * (m > 0)
     dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
     dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
-    dtypes_to_test = [np.float32]
-    if not test.is_built_with_rocm():
-      # complex type is not supported on the ROCm platform
-      dtypes_to_test += [np.complex64]
+    dtypes_to_test = [np.float32, np.complex64]
     for dtype in dtypes_to_test:
       a_mats = sparsify((np.random.randn(*dense_shape_a) +
                          1.j * np.random.randn(*dense_shape_a))).astype(dtype)
@@ -239,10 +233,7 @@ class SparseMatrixMatmulTest(test.TestCase):
     sparsify = lambda m: m * (m > 0)
     dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
     dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
-    dtypes_to_test = [np.float32]
-    if not test.is_built_with_rocm():
-      # complex type is not supported on the ROCm platform
-      dtypes_to_test += [np.complex64]
+    dtypes_to_test = [np.float32, np.complex64]
     for dtype in dtypes_to_test:
       a_mats = (np.random.randn(*dense_shape_a) +
                 1.j * np.random.randn(*dense_shape_a)).astype(dtype)