[ROCm] Adding/Removing no_rocm tag to/from tests
Adding/Removing no_rocm tag to/from tests/subtests that are currently failing/passing on the ROCm platform
This commit is contained in:
parent
13178c55c7
commit
ff523f35ce
@ -87,6 +87,12 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
|
||||
self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
|
||||
|
||||
def testBasicComplexDtypes(self):
|
||||
|
||||
if xla_test.test.is_built_with_rocm():
|
||||
# The folowing subtest invokes the call to "BlasTrsm"
|
||||
# That operation is currently not supported on the ROCm platform
|
||||
self.skipTest("BlasTrsm op for complex types is not supported in ROCm")
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
|
||||
b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
|
||||
|
||||
@ -567,7 +567,10 @@ cc_library(
|
||||
xla_test(
|
||||
name = "logdet_test",
|
||||
srcs = ["logdet_test.cc"],
|
||||
tags = ["optonly"],
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
":logdet",
|
||||
":matrix",
|
||||
|
||||
@ -58,8 +58,10 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"gemm_rewrite_test.cc",
|
||||
],
|
||||
# TODO(b/148106593): Re-enable this test in OSS.
|
||||
tags = ["no_oss"] + tf_cuda_tests_tags(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"no_oss", # TODO(b/148106593): Re-enable this test in OSS.
|
||||
"no_rocm",
|
||||
],
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
@ -141,7 +143,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"tree_reduction_rewriter_test.cc",
|
||||
],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
tags = tf_cuda_tests_tags() + ["no_rocm"],
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
|
||||
@ -42,8 +42,10 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "conv_emitter_test",
|
||||
srcs = ["conv_emitter_test.cc"],
|
||||
# TODO(b/148143101): Test should pass in OSS.
|
||||
tags = ["no_oss"],
|
||||
tags = [
|
||||
"no_oss", # TODO(b/148143101): Test should pass in OSS.
|
||||
"no_rocm",
|
||||
],
|
||||
deps = [
|
||||
":conv_emitter",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
|
||||
@ -1057,7 +1057,6 @@ cuda_py_test(
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm",
|
||||
"no_windows_gpu", # TODO(b/130551176)
|
||||
"noguitar",
|
||||
],
|
||||
|
||||
@ -364,7 +364,10 @@ tf_py_test(
|
||||
srcs = ["integration_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 16,
|
||||
tags = ["notsan"],
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
":keras",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -978,6 +981,7 @@ tf_py_test(
|
||||
python_version = "PY3",
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"notsan", # b/67509773
|
||||
],
|
||||
deps = [
|
||||
|
||||
@ -345,6 +345,7 @@ tf_py_test(
|
||||
srcs = ["recurrent_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
@ -386,7 +387,6 @@ cuda_py_test(
|
||||
srcs = ["lstm_v2_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 12,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
@ -401,7 +401,6 @@ cuda_py_test(
|
||||
srcs = ["gru_v2_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 12,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
|
||||
@ -50,6 +50,9 @@ class BaseFFTOpsTest(test.TestCase):
|
||||
|
||||
def _compare_forward(self, x, rank, fft_length=None, use_placeholder=False,
|
||||
rtol=1e-4, atol=1e-4):
|
||||
if test.is_built_with_rocm() and x.dtype in (np.complex64, np.complex128):
|
||||
self.skipTest("Complex datatype not yet supported in ROCm.")
|
||||
return
|
||||
x_np = self._np_fft(x, rank, fft_length)
|
||||
if use_placeholder:
|
||||
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
||||
@ -61,6 +64,9 @@ class BaseFFTOpsTest(test.TestCase):
|
||||
|
||||
def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False,
|
||||
rtol=1e-4, atol=1e-4):
|
||||
if test.is_built_with_rocm() and x.dtype in (np.complex64, np.complex128):
|
||||
self.skipTest("Complex datatype not yet supported in ROCm.")
|
||||
return
|
||||
x_np = self._np_ifft(x, rank, fft_length)
|
||||
if use_placeholder:
|
||||
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
||||
@ -78,6 +84,9 @@ class BaseFFTOpsTest(test.TestCase):
|
||||
|
||||
def _check_grad_complex(self, func, x, y, result_is_complex=True,
|
||||
rtol=1e-2, atol=1e-2):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("Complex datatype not yet supported in ROCm.")
|
||||
return
|
||||
with self.cached_session(use_gpu=True):
|
||||
def f(inx, iny):
|
||||
inx.set_shape(x.shape)
|
||||
@ -174,6 +183,9 @@ class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
|
||||
itertools.product(VALID_FFT_RANKS, range(3),
|
||||
(np.complex64, np.complex128)))
|
||||
def test_basic(self, rank, extra_dims, np_type):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("Complex datatype not yet supported in ROCm.")
|
||||
return
|
||||
dims = rank + extra_dims
|
||||
tol = 1e-4 if np_type == np.complex64 else 1e-8
|
||||
self._compare(
|
||||
|
||||
@ -127,7 +127,6 @@ cuda_py_test(
|
||||
name = "xla_control_flow_ops_test",
|
||||
srcs = ["xla_control_flow_ops_test.py"],
|
||||
tags = [
|
||||
"no_rocm",
|
||||
# XLA is not enabled by default on Mac or Windows.
|
||||
"no_mac",
|
||||
"no_windows",
|
||||
|
||||
@ -267,6 +267,7 @@ py_test(
|
||||
srcs = ["test_file_v2_0.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user