Add 'test_util.IsMklEnabled()' to guard tests to only mkl.

This commit is contained in:
wenxizhu 2019-07-08 15:47:58 +08:00
parent b47c1aafac
commit b1fca4d98c

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.framework import test_util
def GetTestConfigs(): def GetTestConfigs():
@ -220,6 +221,7 @@ class Conv3DTest(test.TestCase):
expected=expected_output) expected=expected_output)
def testConv3D1x1x1Filter2x1x1Dilation(self): def testConv3D1x1x1Filter2x1x1Dilation(self):
if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
self._VerifyDilatedConvValues( self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 3, 6, 1, 1], tensor_in_sizes=[1, 3, 6, 1, 1],
filter_in_sizes=[1, 1, 1, 1, 1], filter_in_sizes=[1, 1, 1, 1, 1],
@ -244,6 +246,7 @@ class Conv3DTest(test.TestCase):
expected=expected_output) expected=expected_output)
def testConv3D2x2x2Filter1x2x1Dilation(self): def testConv3D2x2x2Filter1x2x1Dilation(self):
if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
self._VerifyDilatedConvValues( self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 4, 6, 3, 1], tensor_in_sizes=[1, 4, 6, 3, 1],
filter_in_sizes=[2, 2, 2, 1, 1], filter_in_sizes=[2, 2, 2, 1, 1],