Merge pull request #35966 from ROCmSoftwarePlatform:google-upstream-sparse-complex
PiperOrigin-RevId: 303368854 Change-Id: If20570eeb12f957557bf0bdd005db6b7bf974f65
This commit is contained in:
commit
bffbe8736e
@ -224,10 +224,8 @@ REGISTER_CPU(complex128)
|
|||||||
|
|
||||||
REGISTER_GPU(float)
|
REGISTER_GPU(float)
|
||||||
REGISTER_GPU(double)
|
REGISTER_GPU(double)
|
||||||
#if GOOGLE_CUDA
|
|
||||||
REGISTER_GPU(complex64)
|
REGISTER_GPU(complex64)
|
||||||
REGISTER_GPU(complex128)
|
REGISTER_GPU(complex128)
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
|||||||
|
|
||||||
REGISTER_GPU(GPU, float)
|
REGISTER_GPU(GPU, float)
|
||||||
REGISTER_GPU(GPU, double)
|
REGISTER_GPU(GPU, double)
|
||||||
#if GOOGLE_CUDA
|
|
||||||
REGISTER_GPU(GPU, complex64)
|
REGISTER_GPU(GPU, complex64)
|
||||||
REGISTER_GPU(GPU, complex128)
|
REGISTER_GPU(GPU, complex128)
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
|
||||||
|
|
||||||
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||||
|
bool use_matrix_vector_multiply = (b_outer_dim == 1);
|
||||||
if (b_outer_dim == 1) {
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
// ROCm hipsparse does not implement csrmv with transposed input a
|
||||||
|
use_matrix_vector_multiply =
|
||||||
|
use_matrix_vector_multiply && !this->transpose_a_;
|
||||||
|
#endif
|
||||||
|
if (use_matrix_vector_multiply) {
|
||||||
// Call matrix-vector multiply if b is a vector.
|
// Call matrix-vector multiply if b is a vector.
|
||||||
TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
|
TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
|
||||||
2);
|
2);
|
||||||
|
@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel {
|
|||||||
|
|
||||||
REGISTER_GPU(float)
|
REGISTER_GPU(float)
|
||||||
REGISTER_GPU(double)
|
REGISTER_GPU(double)
|
||||||
#if GOOGLE_CUDA
|
|
||||||
REGISTER_GPU(complex64)
|
REGISTER_GPU(complex64)
|
||||||
REGISTER_GPU(complex128)
|
REGISTER_GPU(complex128)
|
||||||
#endif
|
|
||||||
|
|
||||||
#undef REGISTER_GPU
|
#undef REGISTER_GPU
|
||||||
|
|
||||||
|
@ -120,10 +120,8 @@ REGISTER(CPU, complex128)
|
|||||||
|
|
||||||
REGISTER(GPU, float)
|
REGISTER(GPU, float)
|
||||||
REGISTER(GPU, double)
|
REGISTER(GPU, double)
|
||||||
#if GOOGLE_CUDA
|
|
||||||
REGISTER(GPU, complex64)
|
REGISTER(GPU, complex64)
|
||||||
REGISTER(GPU, complex128)
|
REGISTER(GPU, complex128)
|
||||||
#endif
|
|
||||||
|
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
@ -141,10 +139,8 @@ namespace functor {
|
|||||||
DECLARE_GPU_SPEC(int32);
|
DECLARE_GPU_SPEC(int32);
|
||||||
DECLARE_GPU_SPEC(float);
|
DECLARE_GPU_SPEC(float);
|
||||||
DECLARE_GPU_SPEC(double);
|
DECLARE_GPU_SPEC(double);
|
||||||
#if GOOGLE_CUDA
|
|
||||||
DECLARE_GPU_SPEC(complex64);
|
DECLARE_GPU_SPEC(complex64);
|
||||||
DECLARE_GPU_SPEC(complex128);
|
DECLARE_GPU_SPEC(complex128);
|
||||||
#endif
|
|
||||||
|
|
||||||
#undef DECLARE_GPU_SPEC
|
#undef DECLARE_GPU_SPEC
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
|||||||
|
|
||||||
REGISTER_GPU(float)
|
REGISTER_GPU(float)
|
||||||
REGISTER_GPU(double)
|
REGISTER_GPU(double)
|
||||||
#if GOOGLE_CUDA
|
|
||||||
REGISTER_GPU(complex64)
|
REGISTER_GPU(complex64)
|
||||||
REGISTER_GPU(complex128)
|
REGISTER_GPU(complex128)
|
||||||
#endif
|
|
||||||
|
|
||||||
#undef REGISTER_GPU
|
#undef REGISTER_GPU
|
||||||
|
|
||||||
|
@ -106,10 +106,7 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
|
|||||||
|
|
||||||
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
||||||
# "medium".
|
# "medium".
|
||||||
dtypes_to_test = [np.float32]
|
dtypes_to_test = [np.float32, np.complex64]
|
||||||
if not test.is_built_with_rocm:
|
|
||||||
# complex type is not supported on the ROCm platform
|
|
||||||
dtypes_to_test += [np.complex64]
|
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
for (t_a, t_b, adj_a, adj_b, t_out,
|
for (t_a, t_b, adj_a, adj_b, t_out,
|
||||||
conj_out) in itertools.product(*(([False, True],) * 6)):
|
conj_out) in itertools.product(*(([False, True],) * 6)):
|
||||||
|
@ -517,9 +517,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testSparseMatrixMatMulConjugateOutput(self):
|
def testSparseMatrixMatMulConjugateOutput(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("complex type not supported on ROCm")
|
|
||||||
|
|
||||||
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
|
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
|
||||||
a_indices = np.array([[0, 0], [2, 3]])
|
a_indices = np.array([[0, 0], [2, 3]])
|
||||||
a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
|
a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
|
||||||
@ -542,17 +539,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testLargeBatchSparseMatrixMatMul(self):
|
def testLargeBatchSparseMatrixMatMul(self):
|
||||||
dtypes_to_test = [np.float32]
|
dtypes_to_test = [np.float32, np.complex64]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# complex types is not supported on the ROCm platform
|
|
||||||
dtypes_to_test += [np.complex64]
|
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# TODO(rocm): fix this
|
|
||||||
# This test is currently failing on the ROCm platform
|
|
||||||
# Ren-enable it once the fix is available
|
|
||||||
self.skipTest("hipSPARSE all failure on the ROCm platform")
|
|
||||||
|
|
||||||
sparsify = lambda m: m * (m > 0)
|
sparsify = lambda m: m * (m > 0)
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
for (transpose_a, transpose_b) in ((False, False), (False, True),
|
for (transpose_a, transpose_b) in ((False, False), (False, True),
|
||||||
|
@ -154,10 +154,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
|||||||
sparsify = lambda m: m * (m > 0)
|
sparsify = lambda m: m * (m > 0)
|
||||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||||
dtypes_to_test = [np.float32]
|
dtypes_to_test = [np.float32, np.complex64]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# complex type is not supported on the ROCm platform
|
|
||||||
dtypes_to_test += [np.complex64]
|
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
||||||
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
||||||
@ -198,10 +195,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
|||||||
sparsify = lambda m: m * (m > 0)
|
sparsify = lambda m: m * (m > 0)
|
||||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||||
dtypes_to_test = [np.float32]
|
dtypes_to_test = [np.float32, np.complex64]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# complex type is not supported on the ROCm platform
|
|
||||||
dtypes_to_test += [np.complex64]
|
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
||||||
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
||||||
@ -239,10 +233,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
|||||||
sparsify = lambda m: m * (m > 0)
|
sparsify = lambda m: m * (m > 0)
|
||||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||||
dtypes_to_test = [np.float32]
|
dtypes_to_test = [np.float32, np.complex64]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# complex type is not supported on the ROCm platform
|
|
||||||
dtypes_to_test += [np.complex64]
|
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
a_mats = (np.random.randn(*dense_shape_a) +
|
a_mats = (np.random.randn(*dense_shape_a) +
|
||||||
1.j * np.random.randn(*dense_shape_a)).astype(dtype)
|
1.j * np.random.randn(*dense_shape_a)).astype(dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user