Use a parameterized test case instead of an in-test for loop.

This breaks up the test into smaller tests, which Bazel can more intelligently separate (and shard!). This should help with timeouts.

PiperOrigin-RevId: 266458623
commit 6168bd4cc2
parent 39684f22f7
Author: Revan Sopher
Date:   2019-08-30 14:40:05 -07:00
Committed by: TensorFlower Gardener

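As background on the pattern this change adopts: `absl.testing.parameterized` expands each parameter set into its own test method at class-definition time, so every combination appears to the test runner (and to Bazel) as an independent test. A minimal, self-contained sketch of the mechanism, with an illustrative class and values that are not part of this change:

from absl.testing import absltest
from absl.testing import parameterized


class SquareTest(parameterized.TestCase):

  # Each dict becomes its own method, named testSquare0, testSquare1, ...
  @parameterized.parameters(
      {'x': 2, 'expected': 4},
      {'x': 3, 'expected': 9},
  )
  def testSquare(self, x, expected):
    self.assertEqual(x * x, expected)


if __name__ == '__main__':
  absltest.main()

Because each generated method is addressable on its own, a failure or timeout is isolated to one parameter set instead of aborting the whole loop.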

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.framework import constant_op
@@ -31,10 +32,10 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 
 
-class ExtractVolumePatchesGradTest(test.TestCase):
+class ExtractVolumePatchesGradTest(test.TestCase, parameterized.TestCase):
   """Gradient-checking for ExtractVolumePatches op."""
 
-  _TEST_CASES = [
+  @parameterized.parameters([
       {
          'in_shape': [2, 5, 5, 5, 3],
          'ksizes': [1, 1, 1, 1, 1],
@@ -55,31 +56,28 @@ class ExtractVolumePatchesGradTest(test.TestCase):
          'ksizes': [1, 2, 3, 2, 1],
          'strides': [1, 2, 4, 3, 1],
       },
-  ]
+  ])
 
   @test_util.run_deprecated_v1
-  def testGradient(self):
+  def testGradient(self, in_shape, ksizes, strides):
     # Set graph seed for determinism.
     random_seed = 42
     random_seed_lib.set_random_seed(random_seed)
 
     with self.cached_session():
-      for test_case in self._TEST_CASES:
-        np.random.seed(random_seed)
-        in_shape = test_case['in_shape']
-        in_val = constant_op.constant(
-            np.random.random(in_shape), dtype=dtypes.float32)
+      np.random.seed(random_seed)
+      in_val = constant_op.constant(
+          np.random.random(in_shape), dtype=dtypes.float32)
 
-        for padding in ['VALID', 'SAME']:
-          out_val = array_ops.extract_volume_patches(
-              in_val, test_case['ksizes'], test_case['strides'], padding)
-          out_shape = out_val.get_shape().as_list()
+      for padding in ['VALID', 'SAME']:
+        out_val = array_ops.extract_volume_patches(
+            in_val, ksizes, strides, padding)
+        out_shape = out_val.get_shape().as_list()
 
-          err = gradient_checker.compute_gradient_error(in_val, in_shape,
-                                                        out_val, out_shape)
-          print('extract_volume_patches gradient err: %.4e' % err)
-          self.assertLess(err, 1e-4)
+        err = gradient_checker.compute_gradient_error(in_val, in_shape,
+                                                      out_val, out_shape)
+        print('extract_volume_patches gradient err: %.4e' % err)
+        self.assertLess(err, 1e-4)
 
   @test_util.run_deprecated_v1
   def testConstructGradientWithLargeVolumess(self):
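With the cases split into separate methods, Bazel test sharding can spread them across shard processes (provided the test target sets `shard_count` in its BUILD rule). Bazel communicates the shard layout through the `TEST_TOTAL_SHARDS` and `TEST_SHARD_INDEX` environment variables, and the test runner executes only the methods assigned to the current shard. A rough Python illustration of that partitioning; the round-robin assignment here is an assumed scheme for clarity, not the runner's exact logic:

import os


def runs_in_this_shard(test_index):
  """Round-robin shard filter mimicking how a sharding-aware test
  runner could assign generated test methods to Bazel shards."""
  total_shards = int(os.environ.get('TEST_TOTAL_SHARDS', '1'))
  shard_index = int(os.environ.get('TEST_SHARD_INDEX', '0'))
  return test_index % total_shards == shard_index


# With TEST_TOTAL_SHARDS=3 and TEST_SHARD_INDEX=1, only every third
# generated test (testGradient1, testGradient4, ...) runs in this process.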