Merge pull request #29815 from lioutasb:extract_image_patches_varsize_grad_fix
PiperOrigin-RevId: 254241100
This commit is contained in:
commit
a88ea5a35c
@ -100,8 +100,6 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
||||
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
|
||||
print('extract_image_patches gradient err: %.4e' % err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -124,6 +122,52 @@ class ExtractImagePatchesGradTest(test.TestCase):
|
||||
# Won't time out.
|
||||
self.assertIsNotNone(gradients)
|
||||
|
||||
def _VariableShapeGradient(self, test_shape_pattern):
|
||||
"""Use test_shape_pattern to infer which dimensions are of
|
||||
|
||||
variable size.
|
||||
"""
|
||||
# Set graph seed for determinism.
|
||||
random_seed = 42
|
||||
random_seed_lib.set_random_seed(random_seed)
|
||||
|
||||
with self.test_session():
|
||||
for test_case in self._TEST_CASES:
|
||||
np.random.seed(random_seed)
|
||||
in_shape = test_case['in_shape']
|
||||
test_shape = [
|
||||
x if x is None else y for x, y in zip(test_shape_pattern, in_shape)
|
||||
]
|
||||
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)
|
||||
|
||||
feed_dict = {in_val: np.random.random(in_shape)}
|
||||
for padding in ['VALID', 'SAME']:
|
||||
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
|
||||
test_case['strides'],
|
||||
test_case['rates'], padding)
|
||||
out_val_tmp = out_val.eval(feed_dict=feed_dict)
|
||||
out_shape = out_val_tmp.shape
|
||||
|
||||
err = gradient_checker.compute_gradient_error(in_val, in_shape,
|
||||
out_val, out_shape)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_BxxC_Gradient(self):
|
||||
self._VariableShapeGradient([-1, None, None, -1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_xHWx_Gradient(self):
|
||||
self._VariableShapeGradient([None, -1, -1, None])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_BHWC_Gradient(self):
|
||||
self._VariableShapeGradient([-1, -1, -1, -1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_AllNone_Gradient(self):
|
||||
self._VariableShapeGradient([None, None, None, None])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -831,12 +831,9 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
|
||||
|
||||
@ops.RegisterGradient("ExtractImagePatches")
|
||||
def _ExtractImagePatchesGrad(op, grad):
|
||||
batch_size, rows_in, cols_in, channels = [
|
||||
dim.value for dim in op.inputs[0].shape.dims
|
||||
]
|
||||
input_bhwc = array_ops.shape(op.inputs[0])
|
||||
batch_size = input_bhwc[0]
|
||||
channels = input_bhwc[3]
|
||||
input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
|
||||
batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
|
||||
input_bhwc[2], input_bhwc[3]
|
||||
|
||||
# Create indices matrix for input tensor.
|
||||
# Note that 0 is preserved for padding location,
|
||||
@ -853,7 +850,8 @@ def _ExtractImagePatchesGrad(op, grad):
|
||||
op.get_attr("padding"))
|
||||
|
||||
# Create indices matrix for output tensor.
|
||||
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
|
||||
output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
|
||||
rows_out, cols_out = output_bhwc[1], output_bhwc[2]
|
||||
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
|
||||
# Indices for output start from 0.
|
||||
output_indices_num = rows_out * cols_out * ksize_r * ksize_c
|
||||
|
Loading…
Reference in New Issue
Block a user