Merge pull request #35966 from ROCmSoftwarePlatform:google-upstream-sparse-complex
PiperOrigin-RevId: 303368854 Change-Id: If20570eeb12f957557bf0bdd005db6b7bf974f65
This commit is contained in:
commit
bffbe8736e
@ -224,10 +224,8 @@ REGISTER_CPU(complex128)
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||
|
||||
REGISTER_GPU(GPU, float)
|
||||
REGISTER_GPU(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(GPU, complex64)
|
||||
REGISTER_GPU(GPU, complex128)
|
||||
#endif
|
||||
|
||||
namespace functor {
|
||||
|
||||
|
@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
|
||||
|
||||
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||
|
||||
if (b_outer_dim == 1) {
|
||||
bool use_matrix_vector_multiply = (b_outer_dim == 1);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// ROCm hipsparse does not implement csrmv with transposed input a
|
||||
use_matrix_vector_multiply =
|
||||
use_matrix_vector_multiply && !this->transpose_a_;
|
||||
#endif
|
||||
if (use_matrix_vector_multiply) {
|
||||
// Call matrix-vector multiply if b is a vector.
|
||||
TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
|
||||
2);
|
||||
|
@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel {
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
|
@ -120,10 +120,8 @@ REGISTER(CPU, complex128)
|
||||
|
||||
REGISTER(GPU, float)
|
||||
REGISTER(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER(GPU, complex64)
|
||||
REGISTER(GPU, complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
@ -141,10 +139,8 @@ namespace functor {
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
#if GOOGLE_CUDA
|
||||
DECLARE_GPU_SPEC(complex64);
|
||||
DECLARE_GPU_SPEC(complex128);
|
||||
#endif
|
||||
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
|
@ -106,10 +106,7 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
|
||||
|
||||
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
||||
# "medium".
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm:
|
||||
# complex type is not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
dtypes_to_test = [np.float32, np.complex64]
|
||||
for dtype in dtypes_to_test:
|
||||
for (t_a, t_b, adj_a, adj_b, t_out,
|
||||
conj_out) in itertools.product(*(([False, True],) * 6)):
|
||||
|
@ -517,9 +517,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSparseMatrixMatMulConjugateOutput(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("complex type not supported on ROCm")
|
||||
|
||||
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
|
||||
a_indices = np.array([[0, 0], [2, 3]])
|
||||
a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
|
||||
@ -542,17 +539,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLargeBatchSparseMatrixMatMul(self):
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm():
|
||||
# complex types are not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# TODO(rocm): fix this
|
||||
# This test is currently failing on the ROCm platform
|
||||
# Re-enable it once the fix is available
|
||||
self.skipTest("hipSPARSE all failure on the ROCm platform")
|
||||
|
||||
dtypes_to_test = [np.float32, np.complex64]
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
for dtype in dtypes_to_test:
|
||||
for (transpose_a, transpose_b) in ((False, False), (False, True),
|
||||
|
@ -154,10 +154,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm():
|
||||
# complex type is not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
dtypes_to_test = [np.float32, np.complex64]
|
||||
for dtype in dtypes_to_test:
|
||||
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
||||
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
||||
@ -198,10 +195,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm():
|
||||
# complex type is not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
dtypes_to_test = [np.float32, np.complex64]
|
||||
for dtype in dtypes_to_test:
|
||||
a_mats = sparsify((np.random.randn(*dense_shape_a) +
|
||||
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
|
||||
@ -239,10 +233,7 @@ class SparseMatrixMatmulTest(test.TestCase):
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
|
||||
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm():
|
||||
# complex type is not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
dtypes_to_test = [np.float32, np.complex64]
|
||||
for dtype in dtypes_to_test:
|
||||
a_mats = (np.random.randn(*dense_shape_a) +
|
||||
1.j * np.random.randn(*dense_shape_a)).astype(dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user