Use parameterized testcase instead of an in-test for loop.
This breaks up the test into smaller tests, which Bazel can more intelligently separate (and shard!). This should help with timeouts. PiperOrigin-RevId: 266458623
This commit is contained in:
parent
39684f22f7
commit
6168bd4cc2
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -31,10 +32,10 @@ from tensorflow.python.ops import variable_scope
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class ExtractVolumePatchesGradTest(test.TestCase):
|
class ExtractVolumePatchesGradTest(test.TestCase, parameterized.TestCase):
|
||||||
"""Gradient-checking for ExtractVolumePatches op."""
|
"""Gradient-checking for ExtractVolumePatches op."""
|
||||||
|
|
||||||
_TEST_CASES = [
|
@parameterized.parameters([
|
||||||
{
|
{
|
||||||
'in_shape': [2, 5, 5, 5, 3],
|
'in_shape': [2, 5, 5, 5, 3],
|
||||||
'ksizes': [1, 1, 1, 1, 1],
|
'ksizes': [1, 1, 1, 1, 1],
|
||||||
@ -55,24 +56,21 @@ class ExtractVolumePatchesGradTest(test.TestCase):
|
|||||||
'ksizes': [1, 2, 3, 2, 1],
|
'ksizes': [1, 2, 3, 2, 1],
|
||||||
'strides': [1, 2, 4, 3, 1],
|
'strides': [1, 2, 4, 3, 1],
|
||||||
},
|
},
|
||||||
]
|
])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testGradient(self):
|
def testGradient(self, in_shape, ksizes, strides):
|
||||||
# Set graph seed for determinism.
|
# Set graph seed for determinism.
|
||||||
random_seed = 42
|
random_seed = 42
|
||||||
random_seed_lib.set_random_seed(random_seed)
|
random_seed_lib.set_random_seed(random_seed)
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
for test_case in self._TEST_CASES:
|
|
||||||
np.random.seed(random_seed)
|
np.random.seed(random_seed)
|
||||||
in_shape = test_case['in_shape']
|
|
||||||
in_val = constant_op.constant(
|
in_val = constant_op.constant(
|
||||||
np.random.random(in_shape), dtype=dtypes.float32)
|
np.random.random(in_shape), dtype=dtypes.float32)
|
||||||
|
|
||||||
for padding in ['VALID', 'SAME']:
|
for padding in ['VALID', 'SAME']:
|
||||||
out_val = array_ops.extract_volume_patches(
|
out_val = array_ops.extract_volume_patches(
|
||||||
in_val, test_case['ksizes'], test_case['strides'], padding)
|
in_val, ksizes, strides, padding)
|
||||||
out_shape = out_val.get_shape().as_list()
|
out_shape = out_val.get_shape().as_list()
|
||||||
|
|
||||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||||
|
Loading…
Reference in New Issue
Block a user