updating the code after getting reviews about cuda test run

This commit is contained in:
mazharul 2020-12-21 19:52:06 -08:00
parent d47153ffe3
commit 5597539069
2 changed files with 5 additions and 27 deletions

View File

@ -43,7 +43,6 @@ from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import _pywrap_utils
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import pywrap_tf_session
@ -1824,12 +1823,6 @@ def _disable_test(execute_func):
return disable_test_impl
# The description is just for documentation purposes.
def disable_nonAVX512f(description): # pylint: disable=unused-argument
"""Execute the test method only if avx512f is supported."""
execute_func = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
return _disable_test(execute_func)
# The description is just for documentation purposes.
def disable_xla(description): # pylint: disable=unused-argument
"""Execute the test method only if xla is not enabled."""

View File

@ -50,6 +50,7 @@ from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import _pywrap_utils
def _input(shape):
@ -371,6 +372,10 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
self.skipTest('No GPU is available')
if mode == 'mkl' and not test_util.IsMklEnabled():
self.skipTest('MKL is not enabled')
# Test will fail on machines without AVX512f, e.g., Broadwell
isAVX512f = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
if mode == 'mkl' and not isAVX512f:
self.skipTest('Skipping test due to non-AVX512f machine')
def _run_simple_loop_test(self, mode, inp, body, out):
"""Runs a test of a simple loop.
@ -428,7 +433,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn(self, mode):
"""Test graph with convolution followed by batch norm."""
@ -460,7 +464,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_conv3d_bn(self, mode):
"""Test graph with convolution followed by batch norm."""
@ -486,7 +489,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_conv3d(self, mode):
"""Test grad ops with convolution3d graph."""
@ -517,7 +519,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn_dropout(self, mode):
"""Test dropout precision of convolution batch norm graph."""
@ -578,7 +579,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
# TODO(benbarsdell): This test has not been tried with MKL.
@parameterized.parameters(['cuda'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_depthwise_conv2d(self, mode):
"""Test grad ops with depthwise convolution2d graph."""
@ -614,7 +614,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_simple_loop(self, mode):
"""Test graph with while loop."""
@ -636,7 +635,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_loop_with_vars_intertwined(self, mode):
"""Test graph with intertwined while loops."""
@ -661,7 +659,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths(self, mode):
"""Test graph with multiple paths."""
@ -691,7 +688,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths_2(self, mode):
"""Test graph with multiple paths."""
@ -725,7 +721,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda']) # MKL doesn't support bf16 Sigmoid
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_recurrent_lstm(self, mode):
"""Test graph with recurrent lstm."""
@ -753,63 +748,54 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_1(self, mode):
self._run_simple_loop_test(mode, 'W', 'C', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_2(self, mode):
self._run_simple_loop_test(mode, 'C', 'C', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_3(self, mode):
self._run_simple_loop_test(mode, 'W', 'G', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_4(self, mode):
self._run_simple_loop_test(mode, 'W', 'gbg', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_5(self, mode):
self._run_simple_loop_test(mode, 'b', 'gWC', 'c')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_6(self, mode):
self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_7(self, mode):
self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235')
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_8(self, mode):
self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_noninlined_funcdef(self, mode):
"""Test graph with non-inlined function subgraph.
@ -838,7 +824,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
@test_util.disable_xla('This test does not pass with XLA')
def test_ingraph_train_loop(self, mode):
"""Tests a graph containing a while loop around a training update.