Merge pull request #3858 from ibab/fix-prod-gradient-scalars
Fix reduce_prod gradient for scalar reduction indices params
This commit is contained in:
commit
e38975049b
@ -254,6 +254,9 @@ class SumReductionTest(tf.test.TestCase):
|
||||
def testGradient4(self):
|
||||
self._compareGradient([2, 3, 4, 2], [], None)
|
||||
|
||||
def testGradient5(self):
|
||||
self._compareGradient([2, 3, 4, 2], [3, 4, 2], 0)
|
||||
|
||||
def testHighRank(self):
|
||||
# Do a bunch of random high dimensional reductions
|
||||
np.random.seed(42)
|
||||
@ -379,6 +382,15 @@ class MeanReductionTest(tf.test.TestCase):
|
||||
delta=1)
|
||||
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
|
||||
|
||||
su = tf.reduce_mean(t, 0)
|
||||
jacob_t, jacob_n = tf.test.compute_gradient(t,
|
||||
s,
|
||||
su,
|
||||
[3, 4, 2],
|
||||
x_init_value=x,
|
||||
delta=1)
|
||||
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def testEmptyGradients(self):
|
||||
with self.test_session():
|
||||
x = tf.zeros([0, 3])
|
||||
@ -463,6 +475,15 @@ class ProdReductionTest(tf.test.TestCase):
|
||||
delta=1)
|
||||
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
|
||||
|
||||
su = tf.reduce_prod(t, 0)
|
||||
jacob_t, jacob_n = tf.test.compute_gradient(t,
|
||||
x.shape,
|
||||
su,
|
||||
[3, 4, 2],
|
||||
x_init_value=x,
|
||||
delta=1)
|
||||
self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def testGradientWithZeros(self):
|
||||
s = [2, 3, 4, 2]
|
||||
x = np.arange(1.0, 49.0).reshape(s).astype(np.float32) / 20.
|
||||
|
@ -116,6 +116,8 @@ def _ProdGrad(op, grad):
|
||||
# cumprod operations.
|
||||
|
||||
input_shape = array_ops.shape(op.inputs[0])
|
||||
# Reshape reduction indices for the case where the parameter is a scalar
|
||||
reduction_indices = array_ops.reshape(op.inputs[1], [-1])
|
||||
|
||||
# Expand grad to full input shape
|
||||
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
|
||||
@ -126,7 +128,7 @@ def _ProdGrad(op, grad):
|
||||
# Pack all reduced dimensions into a single one, so we can perform the
|
||||
# cumprod ops. If the reduction dims list is empty, it defaults to float32,
|
||||
# so we need to cast here.
|
||||
reduced = math_ops.cast(op.inputs[1], dtypes.int32)
|
||||
reduced = math_ops.cast(reduction_indices, dtypes.int32)
|
||||
idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
|
||||
other, _ = array_ops.listdiff(idx, reduced)
|
||||
perm = array_ops.concat(0, [reduced, other])
|
||||
|
Loading…
Reference in New Issue
Block a user