Merge pull request #3858 from ibab/fix-prod-gradient-scalars

Fix reduce_prod gradient for scalar reduction indices parameters
Benoit Steiner, 2016-09-01 11:51:58 -07:00, committed by GitHub
commit e38975049b
2 changed files with 24 additions and 1 deletion
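Context for the fix (not part of the commit itself): _ProdGrad previously assumed the reduction indices input was a vector, so passing the axis as a Python scalar, e.g. tf.reduce_prod(t, 0) instead of tf.reduce_prod(t, [0]), broke the gradient computation. A minimal sketch of the pattern the new tests exercise, assuming the TensorFlow 0.x API in use at the time; with this commit applied, the gradient check passes:

    import numpy as np
    import tensorflow as tf

    s = [2, 3, 4, 2]
    x = np.arange(1.0, 49.0).reshape(s).astype(np.float32) / 20.
    with tf.Session():
      t = tf.convert_to_tensor(x)
      # The axis is a Python scalar (0), not a list ([0]); this is the
      # case that previously broke the reduce_prod gradient.
      su = tf.reduce_prod(t, 0)
      jacob_t, jacob_n = tf.test.compute_gradient(t, s, su, [3, 4, 2],
                                                  x_init_value=x, delta=1)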

tensorflow/python/kernel_tests/reduction_ops_test.py

@@ -254,6 +254,9 @@ class SumReductionTest(tf.test.TestCase):
   def testGradient4(self):
     self._compareGradient([2, 3, 4, 2], [], None)
 
+  def testGradient5(self):
+    self._compareGradient([2, 3, 4, 2], [3, 4, 2], 0)
+
   def testHighRank(self):
     # Do a bunch of random high dimensional reductions
     np.random.seed(42)
@@ -379,6 +382,15 @@ class MeanReductionTest(tf.test.TestCase):
                                                 delta=1)
     self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
 
+    su = tf.reduce_mean(t, 0)
+    jacob_t, jacob_n = tf.test.compute_gradient(t,
+                                                s,
+                                                su,
+                                                [3, 4, 2],
+                                                x_init_value=x,
+                                                delta=1)
+    self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
   def testEmptyGradients(self):
     with self.test_session():
       x = tf.zeros([0, 3])
@@ -463,6 +475,15 @@ class ProdReductionTest(tf.test.TestCase):
                                                 delta=1)
     self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
 
+    su = tf.reduce_prod(t, 0)
+    jacob_t, jacob_n = tf.test.compute_gradient(t,
+                                                x.shape,
+                                                su,
+                                                [3, 4, 2],
+                                                x_init_value=x,
+                                                delta=1)
+    self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
   def testGradientWithZeros(self):
     s = [2, 3, 4, 2]
     x = np.arange(1.0, 49.0).reshape(s).astype(np.float32) / 20.
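All three new test cases pass the axis as the Python scalar 0, so for the [2, 3, 4, 2] input the reduced output has shape [3, 4, 2], which is the y_shape argument handed to tf.test.compute_gradient; the tests then assert that the analytic Jacobian (jacob_t) matches the numeric estimate (jacob_n). A quick plain-NumPy check of that shape arithmetic (illustration only, not part of the change):

    import numpy as np

    x = np.arange(1.0, 49.0).reshape([2, 3, 4, 2]).astype(np.float32) / 20.
    # Reducing over axis 0 drops the leading dimension of size 2.
    assert np.sum(x, axis=0).shape == (3, 4, 2)
    assert np.mean(x, axis=0).shape == (3, 4, 2)
    assert np.prod(x, axis=0).shape == (3, 4, 2)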

tensorflow/python/ops/math_grad.py

@@ -116,6 +116,8 @@ def _ProdGrad(op, grad):
   # cumprod operations.
 
   input_shape = array_ops.shape(op.inputs[0])
+  # Reshape reduction indices for the case where the parameter is a scalar
+  reduction_indices = array_ops.reshape(op.inputs[1], [-1])
 
   # Expand grad to full input shape
   output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
@@ -126,7 +128,7 @@ def _ProdGrad(op, grad):
   # Pack all reduced dimensions into a single one, so we can perform the
   # cumprod ops. If the reduction dims list is empty, it defaults to float32,
   # so we need to cast here.
-  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
+  reduced = math_ops.cast(reduction_indices, dtypes.int32)
   idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
   other, _ = array_ops.listdiff(idx, reduced)
   perm = array_ops.concat(0, [reduced, other])
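The root cause: _ProdGrad feeds the reduction indices into listdiff and concat, which expect a 1-D tensor, but a scalar axis arrives as a rank-0 tensor. array_ops.reshape(op.inputs[1], [-1]) canonicalizes both forms to rank 1 before the cast. A NumPy sketch of the same idea (the helper name is ours, not TensorFlow's):

    import numpy as np

    def flatten_indices(indices):
      # np.reshape(..., [-1]) mirrors array_ops.reshape(op.inputs[1], [-1]):
      # a rank-0 scalar becomes a length-1 vector; a vector is unchanged.
      return np.reshape(indices, [-1])

    assert flatten_indices(0).shape == (1,)       # scalar axis 0 -> [0]
    assert flatten_indices([0, 2]).shape == (2,)  # vector axis unchanged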