Enabling tests affected by support of complex GEMM and GEMV
This commit is contained in:
parent
c329f1c502
commit
543db6fc67
@ -3405,7 +3405,6 @@ tf_py_test(
|
|||||||
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
|
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
tags = [
|
tags = [
|
||||||
"no_rocm", # flaky test
|
|
||||||
"no_windows",
|
"no_windows",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
@ -262,10 +262,7 @@ class BatchMatMulBenchmark(test.Benchmark):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dtypes_to_test = [np.float16, np.float32, np.float64, np.int32]
|
dtypes_to_test = [np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [np.complex64, np.complex128]
|
|
||||||
for dtype_ in dtypes_to_test:
|
for dtype_ in dtypes_to_test:
|
||||||
for adjoint_a_ in False, True:
|
for adjoint_a_ in False, True:
|
||||||
for adjoint_b_ in False, True:
|
for adjoint_b_ in False, True:
|
||||||
|
@ -183,10 +183,7 @@ def _GetEigTest(dtype_, shape_, compute_v_):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
|
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
|
|
||||||
for compute_v in True, False:
|
for compute_v in True, False:
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
for size in 1, 2, 5, 10:
|
for size in 1, 2, 5, 10:
|
||||||
|
@ -746,13 +746,6 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
|
|||||||
else:
|
else:
|
||||||
shape = [4, 16, 16, 16, 64]
|
shape = [4, 16, 16, 16, 64]
|
||||||
convolution = convolutional.conv3d
|
convolution = convolutional.conv3d
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# This subtest triggers a known bug in ROCm runtime code
|
|
||||||
# The bug has been fixed and will be available in ROCm 2.7
|
|
||||||
# Re-enable this test once ROCm 2.7 is released
|
|
||||||
continue
|
|
||||||
|
|
||||||
inputs = random_ops.random_normal(shape, dtype=dtype)
|
inputs = random_ops.random_normal(shape, dtype=dtype)
|
||||||
inputs_2norm = linalg_ops.norm(inputs)
|
inputs_2norm = linalg_ops.norm(inputs)
|
||||||
outputs = convolution(
|
outputs = convolution(
|
||||||
|
@ -141,8 +141,6 @@ class LinearOperatorAdjointTest(
|
|||||||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||||
|
|
||||||
def test_matmul_adjoint_complex_operator(self):
|
def test_matmul_adjoint_complex_operator(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||||
matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||||
full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
|
full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
|
||||||
@ -201,7 +199,7 @@ class LinearOperatorAdjointTest(
|
|||||||
|
|
||||||
def test_solve_adjoint_complex_operator(self):
|
def test_solve_adjoint_complex_operator(self):
|
||||||
if test.is_built_with_rocm():
|
if test.is_built_with_rocm():
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
self.skipTest("ROCm does not support BLAS solve operations for complex types")
|
||||||
matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
|
matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
|
||||||
[4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
|
[4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
|
||||||
1j * linear_operator_test_util.random_tril_matrix(
|
1j * linear_operator_test_util.random_tril_matrix(
|
||||||
|
@ -357,11 +357,6 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
|
|||||||
self.evaluate(operator.assert_non_singular())
|
self.evaluate(operator.assert_non_singular())
|
||||||
|
|
||||||
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
|
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# ROCm does not yet support BLAS operations with complex types.
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
|
|
||||||
spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
|
spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
|
||||||
operator = linalg.LinearOperatorCirculant(spectrum)
|
operator = linalg.LinearOperatorCirculant(spectrum)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -665,11 +660,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||||||
yield sess
|
yield sess
|
||||||
|
|
||||||
def test_real_spectrum_gives_self_adjoint_operator(self):
|
def test_real_spectrum_gives_self_adjoint_operator(self):
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# ROCm does not yet support BLAS operations with complext types
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
# This is a real and hermitian spectrum.
|
# This is a real and hermitian spectrum.
|
||||||
spectrum = linear_operator_test_util.random_normal(
|
spectrum = linear_operator_test_util.random_normal(
|
||||||
@ -686,11 +676,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||||||
self.assertAllClose(matrix, matrix_h)
|
self.assertAllClose(matrix, matrix_h)
|
||||||
|
|
||||||
def test_defining_operator_using_real_convolution_kernel(self):
|
def test_defining_operator_using_real_convolution_kernel(self):
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# ROCm does not yet support BLAS operations with complext types
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
convolution_kernel = linear_operator_test_util.random_normal(
|
convolution_kernel = linear_operator_test_util.random_normal(
|
||||||
shape=(2, 2, 3, 5), dtype=dtypes.float32)
|
shape=(2, 2, 3, 5), dtype=dtypes.float32)
|
||||||
@ -709,11 +694,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||||||
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)
|
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)
|
||||||
|
|
||||||
def test_defining_spd_operator_by_taking_real_part(self):
|
def test_defining_spd_operator_by_taking_real_part(self):
|
||||||
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# ROCm does not yet support BLAS operations with complext types
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
|
|
||||||
with self.cached_session(): # Necessary for fft_kernel_label_map
|
with self.cached_session(): # Necessary for fft_kernel_label_map
|
||||||
# S is real and positive.
|
# S is real and positive.
|
||||||
s = linear_operator_test_util.random_uniform(
|
s = linear_operator_test_util.random_uniform(
|
||||||
|
@ -130,8 +130,6 @@ class LuOpTest(test.TestCase):
|
|||||||
for output_idx_type in (dtypes.int32, dtypes.int64):
|
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||||
self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)
|
self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)
|
||||||
|
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
for dtype in (np.complex64, np.complex128):
|
for dtype in (np.complex64, np.complex128):
|
||||||
for output_idx_type in (dtypes.int32, dtypes.int64):
|
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||||
complex_data = np.tril(1j * data, -1).astype(dtype)
|
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||||
@ -152,8 +150,6 @@ class LuOpTest(test.TestCase):
|
|||||||
# Make sure p_val is not the identity permutation.
|
# Make sure p_val is not the identity permutation.
|
||||||
self.assertNotAllClose(np.arange(3), p_val)
|
self.assertNotAllClose(np.arange(3), p_val)
|
||||||
|
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
for dtype in (np.complex64, np.complex128):
|
for dtype in (np.complex64, np.complex128):
|
||||||
complex_data = np.tril(1j * data, -1).astype(dtype)
|
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||||
complex_data += np.triu(-1j * data, 1).astype(dtype)
|
complex_data += np.triu(-1j * data, 1).astype(dtype)
|
||||||
@ -195,8 +191,6 @@ class LuOpTest(test.TestCase):
|
|||||||
matrices = np.random.rand(batch_size, 5, 5)
|
matrices = np.random.rand(batch_size, 5, 5)
|
||||||
self._verifyLu(matrices)
|
self._verifyLu(matrices)
|
||||||
|
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
# Generate random complex valued matrices.
|
# Generate random complex valued matrices.
|
||||||
np.random.seed(52)
|
np.random.seed(52)
|
||||||
matrices = np.random.rand(batch_size, 5,
|
matrices = np.random.rand(batch_size, 5,
|
||||||
@ -210,8 +204,6 @@ class LuOpTest(test.TestCase):
|
|||||||
data = np.random.rand(n, n)
|
data = np.random.rand(n, n)
|
||||||
self._verifyLu(data)
|
self._verifyLu(data)
|
||||||
|
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
# Generate random complex valued matrices.
|
# Generate random complex valued matrices.
|
||||||
np.random.seed(129)
|
np.random.seed(129)
|
||||||
data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
|
data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
|
||||||
|
@ -226,10 +226,7 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sizes = [1, 3, 5]
|
sizes = [1, 3, 5]
|
||||||
trans_options = [[False, False], [True, False], [False, True]]
|
trans_options = [[False, False], [True, False], [False, True]]
|
||||||
dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64]
|
dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, np.complex128]
|
||||||
if not test_lib.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [np.complex64, np.complex128]
|
|
||||||
# TF2 does not support placeholders under eager so we skip it
|
# TF2 does not support placeholders under eager so we skip it
|
||||||
for use_static_shape in set([True, tf2.enabled()]):
|
for use_static_shape in set([True, tf2.enabled()]):
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
|
@ -91,8 +91,6 @@ class ExponentialOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testNonsymmetricComplex(self):
|
def testNonsymmetricComplex(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||||
matrix1 = matrix1.astype(np.complex64)
|
matrix1 = matrix1.astype(np.complex64)
|
||||||
@ -114,8 +112,6 @@ class ExponentialOpTest(test.TestCase):
|
|||||||
self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
|
self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
|
||||||
|
|
||||||
def testSymmetricPositiveDefiniteComplex(self):
|
def testSymmetricPositiveDefiniteComplex(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||||
matrix1 = matrix1.astype(np.complex64)
|
matrix1 = matrix1.astype(np.complex64)
|
||||||
|
@ -74,9 +74,6 @@ class InverseOpTest(test.TestCase):
|
|||||||
self._verifyInverseReal(matrix2)
|
self._verifyInverseReal(matrix2)
|
||||||
# A multidimensional batch of 2x2 matrices
|
# A multidimensional batch of 2x2 matrices
|
||||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
# Complex
|
|
||||||
matrix1 = matrix1.astype(np.complex64)
|
matrix1 = matrix1.astype(np.complex64)
|
||||||
matrix1 += 1j * matrix1
|
matrix1 += 1j * matrix1
|
||||||
matrix2 = matrix2.astype(np.complex64)
|
matrix2 = matrix2.astype(np.complex64)
|
||||||
@ -94,9 +91,6 @@ class InverseOpTest(test.TestCase):
|
|||||||
self._verifyInverseReal(matrix2)
|
self._verifyInverseReal(matrix2)
|
||||||
# A multidimensional batch of 2x2 matrices
|
# A multidimensional batch of 2x2 matrices
|
||||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
# Complex
|
|
||||||
matrix1 = matrix1.astype(np.complex64)
|
matrix1 = matrix1.astype(np.complex64)
|
||||||
matrix1 += 1j * matrix1
|
matrix1 += 1j * matrix1
|
||||||
matrix2 = matrix2.astype(np.complex64)
|
matrix2 = matrix2.astype(np.complex64)
|
||||||
|
@ -59,8 +59,6 @@ class LogarithmOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testNonsymmetric(self):
|
def testNonsymmetric(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
# 2x2 matrices
|
# 2x2 matrices
|
||||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||||
@ -75,8 +73,6 @@ class LogarithmOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testSymmetricPositiveDefinite(self):
|
def testSymmetricPositiveDefinite(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
# 2x2 matrices
|
# 2x2 matrices
|
||||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||||
@ -111,8 +107,6 @@ class LogarithmOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testRandomSmallAndLargeComplex64(self):
|
def testRandomSmallAndLargeComplex64(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
||||||
for size in 8, 31, 32:
|
for size in 8, 31, 32:
|
||||||
@ -124,8 +118,6 @@ class LogarithmOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testRandomSmallAndLargeComplex128(self):
|
def testRandomSmallAndLargeComplex128(self):
|
||||||
if test.is_built_with_rocm():
|
|
||||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
||||||
for size in 8, 31, 32:
|
for size in 8, 31, 32:
|
||||||
|
@ -59,9 +59,6 @@ class SquareRootOpTest(test.TestCase):
|
|||||||
self._verifySquareRootReal(matrix1)
|
self._verifySquareRootReal(matrix1)
|
||||||
self._verifySquareRootReal(matrix2)
|
self._verifySquareRootReal(matrix2)
|
||||||
self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
|
self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
# Complex
|
|
||||||
matrix1 = matrix1.astype(np.complex64)
|
matrix1 = matrix1.astype(np.complex64)
|
||||||
matrix2 = matrix2.astype(np.complex64)
|
matrix2 = matrix2.astype(np.complex64)
|
||||||
matrix1 += 1j * matrix1
|
matrix1 += 1j * matrix1
|
||||||
|
@ -240,10 +240,7 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
|
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
|
|
||||||
for compute_v in True, False:
|
for compute_v in True, False:
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
for size in 1, 2, 5, 10:
|
for size in 1, 2, 5, 10:
|
||||||
|
@ -123,7 +123,6 @@ cuda_py_tests(
|
|||||||
srcs = ["spectral_ops_test.py"],
|
srcs = ["spectral_ops_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_rocm",
|
|
||||||
"nomac",
|
"nomac",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
@ -370,10 +370,7 @@ class SVDBenchmark(test.Benchmark):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dtypes_to_test = [np.float32, np.float64]
|
dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
|
||||||
if not test.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [np.complex64, np.complex128]
|
|
||||||
for compute_uv in False, True:
|
for compute_uv in False, True:
|
||||||
for full_matrices in False, True:
|
for full_matrices in False, True:
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
@ -392,7 +389,7 @@ if __name__ == "__main__":
|
|||||||
for compute_uv in False, True:
|
for compute_uv in False, True:
|
||||||
for full_matrices in False, True:
|
for full_matrices in False, True:
|
||||||
dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
|
dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
|
||||||
(not compute_uv) * (not test.is_built_with_rocm()))
|
(not compute_uv))
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)]
|
mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)]
|
||||||
if not full_matrices or not compute_uv:
|
if not full_matrices or not compute_uv:
|
||||||
|
@ -221,10 +221,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dtypes_to_test = [np.float16, np.float32, np.float64]
|
dtypes_to_test = [np.float16, np.float32, np.float64, np.complex64, np.complex128]
|
||||||
if not test_lib.is_built_with_rocm():
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes_to_test += [np.complex64, np.complex128]
|
|
||||||
for dtype in dtypes_to_test:
|
for dtype in dtypes_to_test:
|
||||||
for rank_a in 1, 2, 4, 5:
|
for rank_a in 1, 2, 4, 5:
|
||||||
for rank_b in 1, 2, 4, 5:
|
for rank_b in 1, 2, 4, 5:
|
||||||
|
@ -338,12 +338,6 @@ class EinsumTest(test.TestCase):
|
|||||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||||
|
|
||||||
def test_dtypes(self):
|
def test_dtypes(self):
|
||||||
dtypes = []
|
|
||||||
if test.is_built_with_rocm():
|
|
||||||
# This test triggers the BLAS op calls on the GPU
|
|
||||||
# ROCm does not support BLAS operations for complex types
|
|
||||||
dtypes = [np.float64, np.float32]
|
|
||||||
else:
|
|
||||||
dtypes = [np.float64, np.float32, np.complex64, np.complex128]
|
dtypes = [np.float64, np.float32, np.complex64, np.complex128]
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
|
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user