addressing code-review comments

This commit is contained in:
Deven Desai 2019-12-05 03:20:49 +00:00
parent e762347e79
commit 5d1ccc1eee
3 changed files with 6 additions and 6 deletions

View File

@ -85,7 +85,6 @@ class CSRSparseMatrixGradTest(test.TestCase):
return
if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")
sparsify = lambda m: m * (m > 0)

View File

@ -433,7 +433,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
return
if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")
a_indices = np.array([[0, 0], [2, 3]])
@ -474,7 +473,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
return
if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")
sparsify = lambda m: m * (m > 0)
@ -520,7 +518,6 @@ class CSRSparseMatrixOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testSparseMatrixMatMulConjugateOutput(self):
if test.is_built_with_rocm():
# complex types are not yet supported on the ROCm platform
self.skipTest("complex type not supported on ROCm")
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
@ -552,6 +549,8 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Ren-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")
sparsify = lambda m: m * (m > 0)
@ -612,6 +611,8 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Ren-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")
sparsify = lambda m: m * (m > 0)

View File

@ -238,7 +238,7 @@ def _rocm_include_path(repository_ctx, rocm_config):
return inc_dirs
def enable_rocm(repository_ctx):
def _enable_rocm(repository_ctx):
if "TF_NEED_ROCM" in repository_ctx.os.environ:
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
if enable_rocm == "1":
@ -895,7 +895,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
def _rocm_autoconf_impl(repository_ctx):
"""Implementation of the rocm_autoconf repository rule."""
if not enable_rocm(repository_ctx):
if not _enable_rocm(repository_ctx):
_create_dummy_repository(repository_ctx)
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
_create_remote_rocm_repository(