Merge pull request #29815 from lioutasb:extract_image_patches_varsize_grad_fix

PiperOrigin-RevId: 254241100
This commit is contained in:
TensorFlower Gardener 2019-06-20 12:26:32 -07:00
commit a88ea5a35c
2 changed files with 51 additions and 9 deletions

View File

@ -100,8 +100,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
err = gradient_checker.compute_gradient_error(in_val, in_shape,
out_val, out_shape)
print('extract_image_patches gradient err: %.4e' % err)
self.assertLess(err, 1e-4)
@test_util.run_deprecated_v1
@ -124,6 +122,52 @@ class ExtractImagePatchesGradTest(test.TestCase):
# Won't time out.
self.assertIsNotNone(gradients)
def _VariableShapeGradient(self, test_shape_pattern):
"""Use test_shape_pattern to infer which dimensions are of
variable size.
"""
# Set graph seed for determinism.
random_seed = 42
random_seed_lib.set_random_seed(random_seed)
with self.test_session():
for test_case in self._TEST_CASES:
np.random.seed(random_seed)
in_shape = test_case['in_shape']
test_shape = [
x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
]
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
feed_dict = {in_val: np.random.random(in_shape)}
for padding in ['VALID', 'SAME']:
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
test_case['strides'],
test_case['rates'], padding)
out_val_tmp = out_val.eval(feed_dict=feed_dict)
out_shape = out_val_tmp.shape
err = gradient_checker.compute_gradient_error(in_val, in_shape,
out_val, out_shape)
self.assertLess(err, 1e-4)
@test_util.run_deprecated_v1
def test_BxxC_Gradient(self):
self._VariableShapeGradient([-1, None, None, -1])
@test_util.run_deprecated_v1
def test_xHWx_Gradient(self):
self._VariableShapeGradient([None, -1, -1, None])
@test_util.run_deprecated_v1
def test_BHWC_Gradient(self):
self._VariableShapeGradient([-1, -1, -1, -1])
@test_util.run_deprecated_v1
def test_AllNone_Gradient(self):
self._VariableShapeGradient([None, None, None, None])
if __name__ == '__main__':
test.main()

View File

@ -831,12 +831,9 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
batch_size, rows_in, cols_in, channels = [
dim.value for dim in op.inputs[0].shape.dims
]
input_bhwc = array_ops.shape(op.inputs[0])
batch_size = input_bhwc[0]
channels = input_bhwc[3]
input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
input_bhwc[2], input_bhwc[3]
# Create indices matrix for input tensor.
# Note that 0 is preserved for padding location,
@ -853,7 +850,8 @@ def _ExtractImagePatchesGrad(op, grad):
op.get_attr("padding"))
# Create indices matrix for output tensor.
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
rows_out, cols_out = output_bhwc[1], output_bhwc[2]
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
# Indices for output start from 0.
output_indices_num = rows_out * cols_out * ksize_r * ksize_c