Merge pull request #29815 from lioutasb:extract_image_patches_varsize_grad_fix
PiperOrigin-RevId: 254241100
This commit is contained in:
commit
a88ea5a35c
tensorflow/python
@ -100,8 +100,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
|||||||
|
|
||||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||||
out_val, out_shape)
|
out_val, out_shape)
|
||||||
|
|
||||||
print('extract_image_patches gradient err: %.4e' % err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@ -124,6 +122,52 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
|||||||
# Won't time out.
|
# Won't time out.
|
||||||
self.assertIsNotNone(gradients)
|
self.assertIsNotNone(gradients)
|
||||||
|
|
||||||
|
def _VariableShapeGradient(self, test_shape_pattern):
|
||||||
|
"""Use test_shape_pattern to infer which dimensions are of
|
||||||
|
|
||||||
|
variable size.
|
||||||
|
"""
|
||||||
|
# Set graph seed for determinism.
|
||||||
|
random_seed = 42
|
||||||
|
random_seed_lib.set_random_seed(random_seed)
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
for test_case in self._TEST_CASES:
|
||||||
|
np.random.seed(random_seed)
|
||||||
|
in_shape = test_case['in_shape']
|
||||||
|
test_shape = [
|
||||||
|
x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
|
||||||
|
]
|
||||||
|
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
feed_dict = {in_val: np.random.random(in_shape)}
|
||||||
|
for padding in ['VALID', 'SAME']:
|
||||||
|
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
|
||||||
|
test_case['strides'],
|
||||||
|
test_case['rates'], padding)
|
||||||
|
out_val_tmp = out_val.eval(feed_dict=feed_dict)
|
||||||
|
out_shape = out_val_tmp.shape
|
||||||
|
|
||||||
|
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||||
|
out_val, out_shape)
|
||||||
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_BxxC_Gradient(self):
|
||||||
|
self._VariableShapeGradient([-1, None, None, -1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_xHWx_Gradient(self):
|
||||||
|
self._VariableShapeGradient([None, -1, -1, None])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_BHWC_Gradient(self):
|
||||||
|
self._VariableShapeGradient([-1, -1, -1, -1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_AllNone_Gradient(self):
|
||||||
|
self._VariableShapeGradient([None, None, None, None])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -831,12 +831,9 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
|
|||||||
|
|
||||||
@ops.RegisterGradient("ExtractImagePatches")
|
@ops.RegisterGradient("ExtractImagePatches")
|
||||||
def _ExtractImagePatchesGrad(op, grad):
|
def _ExtractImagePatchesGrad(op, grad):
|
||||||
batch_size, rows_in, cols_in, channels = [
|
input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
|
||||||
dim.value for dim in op.inputs[0].shape.dims
|
batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
|
||||||
]
|
input_bhwc[2], input_bhwc[3]
|
||||||
input_bhwc = array_ops.shape(op.inputs[0])
|
|
||||||
batch_size = input_bhwc[0]
|
|
||||||
channels = input_bhwc[3]
|
|
||||||
|
|
||||||
# Create indices matrix for input tensor.
|
# Create indices matrix for input tensor.
|
||||||
# Note that 0 is preserved for padding location,
|
# Note that 0 is preserved for padding location,
|
||||||
@ -853,7 +850,8 @@ def _ExtractImagePatchesGrad(op, grad):
|
|||||||
op.get_attr("padding"))
|
op.get_attr("padding"))
|
||||||
|
|
||||||
# Create indices matrix for output tensor.
|
# Create indices matrix for output tensor.
|
||||||
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
|
output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
|
||||||
|
rows_out, cols_out = output_bhwc[1], output_bhwc[2]
|
||||||
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
|
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
|
||||||
# Indices for output start from 0.
|
# Indices for output start from 0.
|
||||||
output_indices_num = rows_out * cols_out * ksize_r * ksize_c
|
output_indices_num = rows_out * cols_out * ksize_r * ksize_c
|
||||||
|
Loading…
Reference in New Issue
Block a user