Use parameterized testcase instead of an in-test for loop.
This breaks up the test into smaller tests, which Bazel can more intelligently separate (and shard!). This should help with timeouts. PiperOrigin-RevId: 266458623
This commit is contained in:
parent
39684f22f7
commit
6168bd4cc2
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -31,10 +32,10 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ExtractVolumePatchesGradTest(test.TestCase):
|
||||
class ExtractVolumePatchesGradTest(test.TestCase, parameterized.TestCase):
|
||||
"""Gradient-checking for ExtractVolumePatches op."""
|
||||
|
||||
_TEST_CASES = [
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'in_shape': [2, 5, 5, 5, 3],
|
||||
'ksizes': [1, 1, 1, 1, 1],
|
||||
@ -55,31 +56,28 @@ class ExtractVolumePatchesGradTest(test.TestCase):
|
||||
'ksizes': [1, 2, 3, 2, 1],
|
||||
'strides': [1, 2, 4, 3, 1],
|
||||
},
|
||||
]
|
||||
|
||||
])
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradient(self):
|
||||
def testGradient(self, in_shape, ksizes, strides):
|
||||
# Set graph seed for determinism.
|
||||
random_seed = 42
|
||||
random_seed_lib.set_random_seed(random_seed)
|
||||
|
||||
with self.cached_session():
|
||||
for test_case in self._TEST_CASES:
|
||||
np.random.seed(random_seed)
|
||||
in_shape = test_case['in_shape']
|
||||
in_val = constant_op.constant(
|
||||
np.random.random(in_shape), dtype=dtypes.float32)
|
||||
np.random.seed(random_seed)
|
||||
in_val = constant_op.constant(
|
||||
np.random.random(in_shape), dtype=dtypes.float32)
|
||||
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_volume_patches(
|
||||
in_val, test_case['ksizes'], test_case['strides'], padding)
|
||||
out_shape = out_val.get_shape().as_list()
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_volume_patches(
|
||||
in_val, ksizes, strides, padding)
|
||||
out_shape = out_val.get_shape().as_list()
|
||||
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
|
||||
print('extract_volume_patches gradient err: %.4e' % err)
|
||||
self.assertLess(err, 1e-4)
|
||||
print('extract_volume_patches gradient err: %.4e' % err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConstructGradientWithLargeVolumess(self):
|
||||
|
Loading…
Reference in New Issue
Block a user