updating the code after getting reviews about cuda test run
This commit is contained in:
parent
d47153ffe3
commit
5597539069
@ -43,7 +43,6 @@ from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
@ -1824,12 +1823,6 @@ def _disable_test(execute_func):
|
||||
return disable_test_impl
|
||||
|
||||
|
||||
# The description is just for documentation purposes.
|
||||
def disable_nonAVX512f(description): # pylint: disable=unused-argument
|
||||
"""Execute the test method only if avx512f is supported."""
|
||||
execute_func = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
|
||||
return _disable_test(execute_func)
|
||||
|
||||
# The description is just for documentation purposes.
|
||||
def disable_xla(description): # pylint: disable=unused-argument
|
||||
"""Execute the test method only if xla is not enabled."""
|
||||
|
@ -50,6 +50,7 @@ from tensorflow.python.platform import sysconfig
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.util import _pywrap_utils
|
||||
|
||||
|
||||
def _input(shape):
|
||||
@ -371,6 +372,10 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
self.skipTest('No GPU is available')
|
||||
if mode == 'mkl' and not test_util.IsMklEnabled():
|
||||
self.skipTest('MKL is not enabled')
|
||||
# Test will fail on machines without AVX512f, e.g., Broadwell
|
||||
isAVX512f = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
|
||||
if mode == 'mkl' and not isAVX512f:
|
||||
self.skipTest('Skipping test due to non-AVX512f machine')
|
||||
|
||||
def _run_simple_loop_test(self, mode, inp, body, out):
|
||||
"""Runs a test of a simple loop.
|
||||
@ -428,7 +433,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn(self, mode):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
@ -460,7 +464,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv3d_bn(self, mode):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
@ -486,7 +489,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv3d(self, mode):
|
||||
"""Test grad ops with convolution3d graph."""
|
||||
@ -517,7 +519,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn_dropout(self, mode):
|
||||
"""Test dropout precision of convolution batch norm graph."""
|
||||
@ -578,7 +579,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
# TODO(benbarsdell): This test has not been tried with MKL.
|
||||
@parameterized.parameters(['cuda'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_depthwise_conv2d(self, mode):
|
||||
"""Test grad ops with depthwise convolution2d graph."""
|
||||
@ -614,7 +614,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_simple_loop(self, mode):
|
||||
"""Test graph with while loop."""
|
||||
@ -636,7 +635,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_loop_with_vars_intertwined(self, mode):
|
||||
"""Test graph with intertwined while loops."""
|
||||
@ -661,7 +659,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_multi_paths(self, mode):
|
||||
"""Test graph with multiple paths."""
|
||||
@ -691,7 +688,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_multi_paths_2(self, mode):
|
||||
"""Test graph with multiple paths."""
|
||||
@ -725,7 +721,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda']) # MKL doesn't support bf16 Sigmoid
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_recurrent_lstm(self, mode):
|
||||
"""Test graph with recurrent lstm."""
|
||||
@ -753,63 +748,54 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('v1 loop test')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_1(self, mode):
|
||||
self._run_simple_loop_test(mode, 'W', 'C', 'C')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('v1 loop test')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_2(self, mode):
|
||||
self._run_simple_loop_test(mode, 'C', 'C', 'W')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('v1 loop test')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_3(self, mode):
|
||||
self._run_simple_loop_test(mode, 'W', 'G', 'W')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('v1 loop test')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_4(self, mode):
|
||||
self._run_simple_loop_test(mode, 'W', 'gbg', 'W')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_5(self, mode):
|
||||
self._run_simple_loop_test(mode, 'b', 'gWC', 'c')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_6(self, mode):
|
||||
self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_7(self, mode):
|
||||
self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_propagation_through_simple_loop_8(self, mode):
|
||||
self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_noninlined_funcdef(self, mode):
|
||||
"""Test graph with non-inlined function subgraph.
|
||||
@ -838,7 +824,6 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_nonAVX512f('Test will fail on machines without AVX512f, e.g., Broadwell')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_ingraph_train_loop(self, mode):
|
||||
"""Tests a graph containing a while loop around a training update.
|
||||
|
Loading…
Reference in New Issue
Block a user