addressing code-review comments
This commit is contained in:
parent
e762347e79
commit
5d1ccc1eee
@ -85,7 +85,6 @@ class CSRSparseMatrixGradTest(test.TestCase):
|
||||
return
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# sparse-matrix-add op is not yet supported on the ROCm platform
|
||||
self.skipTest("sparse-matrix-add op not supported on ROCm")
|
||||
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
|
@ -433,7 +433,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
return
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# sparse-matrix-add op is not yet supported on the ROCm platform
|
||||
self.skipTest("sparse-matrix-add op not supported on ROCm")
|
||||
|
||||
a_indices = np.array([[0, 0], [2, 3]])
|
||||
@ -474,7 +473,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
return
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# sparse-matrix-add op is not yet supported on the ROCm platform
|
||||
self.skipTest("sparse-matrix-add op not supported on ROCm")
|
||||
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
@ -520,7 +518,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSparseMatrixMatMulConjugateOutput(self):
|
||||
if test.is_built_with_rocm():
|
||||
# complex types are not yet supported on the ROCm platform
|
||||
self.skipTest("complex type not supported on ROCm")
|
||||
|
||||
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
|
||||
@ -552,6 +549,8 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# TODO(rocm): fix this
|
||||
# This test is currently failing on the ROCm platform
|
||||
# Ren-enable it once the fix is available
|
||||
self.skipTest("hipSPARSE all failure on the ROCm platform")
|
||||
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
@ -612,6 +611,8 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# TODO(rocm): fix this
|
||||
# This test is currently failing on the ROCm platform
|
||||
# Ren-enable it once the fix is available
|
||||
self.skipTest("hipSPARSE all failure on the ROCm platform")
|
||||
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
|
4
third_party/gpus/rocm_configure.bzl
vendored
4
third_party/gpus/rocm_configure.bzl
vendored
@ -238,7 +238,7 @@ def _rocm_include_path(repository_ctx, rocm_config):
|
||||
|
||||
return inc_dirs
|
||||
|
||||
def enable_rocm(repository_ctx):
|
||||
def _enable_rocm(repository_ctx):
|
||||
if "TF_NEED_ROCM" in repository_ctx.os.environ:
|
||||
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
|
||||
if enable_rocm == "1":
|
||||
@ -895,7 +895,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
|
||||
|
||||
def _rocm_autoconf_impl(repository_ctx):
|
||||
"""Implementation of the rocm_autoconf repository rule."""
|
||||
if not enable_rocm(repository_ctx):
|
||||
if not _enable_rocm(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
|
||||
_create_remote_rocm_repository(
|
||||
|
Loading…
Reference in New Issue
Block a user