Split the tests so that it doesn't time out.

PiperOrigin-RevId: 178185460
This commit is contained in:
Yao Zhang 2017-12-06 18:40:17 -08:00 committed by TensorFlower Gardener
parent a51cc5801f
commit 8ad62af489
2 changed files with 35 additions and 12 deletions

View File

@ -2624,7 +2624,7 @@ cuda_py_test(
":nn_grad",
"//third_party/py/numpy",
],
shard_count = 4,
shard_count = 16,
)
cuda_py_test(

View File

@ -333,7 +333,7 @@ class BatchNormalizationTest(test.TestCase):
self.assertLess(err_grad_x_2, err_tolerance)
self.assertLess(err_grad_scale, err_tolerance)
def testInference(self):
def testInferenceShape1(self):
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -344,6 +344,7 @@ class BatchNormalizationTest(test.TestCase):
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
def testInferenceShape2(self):
x_shape = [1, 1, 6, 2]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
@ -352,12 +353,14 @@ class BatchNormalizationTest(test.TestCase):
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
def testInferenceShape3(self):
x_shape = [1, 2, 1, 6]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
def testInferenceShape4(self):
x_shape = [27, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -368,7 +371,7 @@ class BatchNormalizationTest(test.TestCase):
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testTraining(self):
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -379,6 +382,7 @@ class BatchNormalizationTest(test.TestCase):
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape2(self):
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -387,12 +391,14 @@ class BatchNormalizationTest(test.TestCase):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape3(self):
x_shape = [1, 2, 1, 6]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
def testTrainingShape4(self):
x_shape = [27, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -403,7 +409,7 @@ class BatchNormalizationTest(test.TestCase):
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testBatchNormGrad(self):
def testBatchNormGradShape1(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
@ -430,6 +436,8 @@ class BatchNormalizationTest(test.TestCase):
data_format='NHWC',
is_training=is_training)
def testBatchNormGradShape2(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -448,6 +456,8 @@ class BatchNormalizationTest(test.TestCase):
data_format='NHWC',
is_training=is_training)
def testBatchNormGradShape3(self):
for is_training in [True, False]:
x_shape = [1, 2, 1, 6]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
@ -459,6 +469,8 @@ class BatchNormalizationTest(test.TestCase):
data_format='NCHW',
is_training=is_training)
def testBatchNormGradShape4(self):
for is_training in [True, False]:
x_shape = [5, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
@ -515,26 +527,37 @@ class BatchNormalizationTest(test.TestCase):
is_training=is_training,
err_tolerance=err_tolerance)
def testBatchNormGradGrad(self):
configs = [{
def testBatchNormGradGradConfig1(self):
config = {
'shape': [2, 3, 4, 5],
'err_tolerance': 1e-2,
'dtype': np.float32,
}, {
}
self._testBatchNormGradGrad(config)
def testBatchNormGradGradConfig2(self):
config = {
'shape': [2, 3, 2, 2],
'err_tolerance': 1e-3,
'dtype': np.float32,
}, {
}
self._testBatchNormGradGrad(config)
def testBatchNormGradGradConfig3(self):
config = {
'shape': [2, 3, 4, 5],
'err_tolerance': 1e-2,
'dtype': np.float16,
}, {
}
self._testBatchNormGradGrad(config)
def testBatchNormGradGradConfig4(self):
config = {
'shape': [2, 3, 2, 2],
'err_tolerance': 2e-3,
'dtype': np.float16,
}]
for config in configs:
self._testBatchNormGradGrad(config)
}
self._testBatchNormGradGrad(config)
if __name__ == '__main__':