Split the tests so that they don't time out.
PiperOrigin-RevId: 178185460
@@ -2624,7 +2624,7 @@ cuda_py_test(
         ":nn_grad",
         "//third_party/py/numpy",
     ],
-    shard_count = 4,
+    shard_count = 16,
 )
 
 cuda_py_test(
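Note on the shard_count bump: Bazel runs a sharded test as several parallel processes and tells each process which slice of the test methods to run via the TEST_TOTAL_SHARDS and TEST_SHARD_INDEX environment variables. Below is a minimal sketch of that protocol in Python, assuming a simple round-robin assignment policy; it is an illustration, not the actual Bazel or TensorFlow test runner.

import os


def methods_for_this_shard(method_names):
  # Bazel sets these for each of the TEST_TOTAL_SHARDS processes.
  total = int(os.environ.get('TEST_TOTAL_SHARDS', '1'))
  index = int(os.environ.get('TEST_SHARD_INDEX', '0'))
  # Round-robin policy (an assumption for illustration): method i runs on
  # shard i % total. A single slow method still runs whole on one shard,
  # which is why raising shard_count alone cannot fix a monolithic test.
  return [m for i, m in enumerate(method_names) if i % total == index]


# With TEST_TOTAL_SHARDS=16 and TEST_SHARD_INDEX=0, shard 0 of a
# 20-method suite would run methods 0 and 16.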
@@ -333,7 +333,7 @@ class BatchNormalizationTest(test.TestCase):
       self.assertLess(err_grad_x_2, err_tolerance)
       self.assertLess(err_grad_scale, err_tolerance)
 
-  def testInference(self):
+  def testInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
     for dtype in [np.float16, np.float32]:
       if test.is_gpu_available(cuda_only=True):
@@ -344,6 +344,7 @@ class BatchNormalizationTest(test.TestCase):
       self._test_inference(
           x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testInferenceShape2(self):
     x_shape = [1, 1, 6, 2]
     if test.is_gpu_available(cuda_only=True):
       for dtype in [np.float16, np.float32]:
@@ -352,12 +353,14 @@ class BatchNormalizationTest(test.TestCase):
       self._test_inference(
           x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testInferenceShape3(self):
     x_shape = [1, 2, 1, 6]
     if test.is_gpu_available(cuda_only=True):
       for dtype in [np.float16, np.float32]:
         self._test_inference(
             x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
 
+  def testInferenceShape4(self):
     x_shape = [27, 131, 127, 6]
     for dtype in [np.float16, np.float32]:
       if test.is_gpu_available(cuda_only=True):
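Why the method-level split helps: sharding distributes whole test methods, so one testInference that loops over every shape is a single indivisible unit of work, while testInferenceShape1 through testInferenceShape4 can each land on a different shard. A toy sketch (hypothetical class, bodies elided) of how a runner sees methods as separate units:

import unittest


class ToyBatchNormTest(unittest.TestCase):
  """Stand-in for BatchNormalizationTest; real test bodies elided."""

  def testInferenceShape1(self):
    pass

  def testInferenceShape2(self):
    pass


# The loader reports each method separately, so a sharding runner can place
# them on different shards; a single monolithic testInference would appear
# as one entry no matter how long it runs.
print(unittest.TestLoader().getTestCaseNames(ToyBatchNormTest))
# -> ['testInferenceShape1', 'testInferenceShape2']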
@@ -368,7 +371,7 @@ class BatchNormalizationTest(test.TestCase):
       self._test_inference(
           x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
-  def testTraining(self):
+  def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     for dtype in [np.float16, np.float32]:
       if test.is_gpu_available(cuda_only=True):
@@ -379,6 +382,7 @@ class BatchNormalizationTest(test.TestCase):
       self._test_training(
           x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testTrainingShape2(self):
     x_shape = [1, 1, 6, 2]
     for dtype in [np.float16, np.float32]:
       if test.is_gpu_available(cuda_only=True):
@@ -387,12 +391,14 @@ class BatchNormalizationTest(test.TestCase):
       self._test_training(
           x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testTrainingShape3(self):
     x_shape = [1, 2, 1, 6]
     if test.is_gpu_available(cuda_only=True):
       for dtype in [np.float16, np.float32]:
         self._test_training(
             x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
 
+  def testTrainingShape4(self):
     x_shape = [27, 131, 127, 6]
     for dtype in [np.float16, np.float32]:
       if test.is_gpu_available(cuda_only=True):
@@ -403,7 +409,7 @@ class BatchNormalizationTest(test.TestCase):
       self._test_training(
           x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
-  def testBatchNormGrad(self):
+  def testBatchNormGradShape1(self):
     for is_training in [True, False]:
       x_shape = [1, 1, 6, 1]
       for dtype in [np.float16, np.float32]:
@@ -430,6 +436,8 @@ class BatchNormalizationTest(test.TestCase):
             data_format='NHWC',
             is_training=is_training)
 
+  def testBatchNormGradShape2(self):
+    for is_training in [True, False]:
       x_shape = [1, 1, 6, 2]
       for dtype in [np.float16, np.float32]:
         if test.is_gpu_available(cuda_only=True):
@@ -448,6 +456,8 @@ class BatchNormalizationTest(test.TestCase):
             data_format='NHWC',
             is_training=is_training)
 
+  def testBatchNormGradShape3(self):
+    for is_training in [True, False]:
       x_shape = [1, 2, 1, 6]
       if test.is_gpu_available(cuda_only=True):
         for dtype in [np.float16, np.float32]:
@@ -459,6 +469,8 @@ class BatchNormalizationTest(test.TestCase):
               data_format='NCHW',
               is_training=is_training)
 
+  def testBatchNormGradShape4(self):
+    for is_training in [True, False]:
       x_shape = [5, 7, 11, 4]
       for dtype in [np.float16, np.float32]:
         if test.is_gpu_available(cuda_only=True):
@@ -515,26 +527,37 @@ class BatchNormalizationTest(test.TestCase):
           is_training=is_training,
           err_tolerance=err_tolerance)
 
-  def testBatchNormGradGrad(self):
-    configs = [{
+  def testBatchNormGradGradConfig1(self):
+    config = {
         'shape': [2, 3, 4, 5],
         'err_tolerance': 1e-2,
         'dtype': np.float32,
-    }, {
+    }
+    self._testBatchNormGradGrad(config)
+
+  def testBatchNormGradGradConfig2(self):
+    config = {
         'shape': [2, 3, 2, 2],
         'err_tolerance': 1e-3,
         'dtype': np.float32,
-    }, {
+    }
+    self._testBatchNormGradGrad(config)
+
+  def testBatchNormGradGradConfig3(self):
+    config = {
         'shape': [2, 3, 4, 5],
         'err_tolerance': 1e-2,
         'dtype': np.float16,
-    }, {
+    }
+    self._testBatchNormGradGrad(config)
+
+  def testBatchNormGradGradConfig4(self):
+    config = {
         'shape': [2, 3, 2, 2],
         'err_tolerance': 2e-3,
         'dtype': np.float16,
-    }]
-    for config in configs:
-      self._testBatchNormGradGrad(config)
+    }
+    self._testBatchNormGradGrad(config)
 
 
 if __name__ == '__main__':
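For reference, the same effect could be had without writing the Config1 through Config4 methods by hand. The sketch below uses absl's parameterized tests, an alternative approach rather than what this commit does; each named case is generated as its own test method, so sharding can still separate the configs. The class name and stand-in body are hypothetical.

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np


class BatchNormGradGradTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('Config1', [2, 3, 4, 5], 1e-2, np.float32),
      ('Config2', [2, 3, 2, 2], 1e-3, np.float32),
      ('Config3', [2, 3, 4, 5], 1e-2, np.float16),
      ('Config4', [2, 3, 2, 2], 2e-3, np.float16),
  )
  def testBatchNormGradGrad(self, shape, err_tolerance, dtype):
    # Stand-in body: the real test would call _testBatchNormGradGrad with
    # these values. Generated names are testBatchNormGradGradConfig1, etc.
    self.assertLen(shape, 4)


if __name__ == '__main__':
  absltest.main()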