Add SegmentProdGrad

Fix typo
This commit is contained in:
Tzu-Wei Sung 2020-08-10 09:41:27 -07:00
parent 9df7b750eb
commit c5b15cbca8
2 changed files with 79 additions and 0 deletions

View File

@ -414,6 +414,48 @@ def _SegmentMaxGrad(op, grad):
return _SegmentMinOrMaxGrad(op, grad)
@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op, grad):
  """Gradient for SegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use cumsum here as individual segments may have
  a different number of elements. Therefore we consider three cases:
  1) A segment input contains no zeros and we can safely divide by the input
     tensor.
  2) A segment contains exactly one zero. Then the gradient of each input of
     the segment is zero except for the 0-input, where the gradient is
     the product of the remaining segment entries.
  3) A segment contains at least two zeros. The gradient is zero for all
     segment inputs.

  Args:
    op: The SegmentProd op. `op.inputs[0]` is the data tensor,
      `op.inputs[1]` holds the segment ids and `op.outputs[0]` is the
      per-segment product.
    grad: Gradient with respect to the output of the op.

  Returns:
    A pair `(data_grad, None)`: the gradient with respect to the data input
    and `None` for the segment ids.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  is_zero = math_ops.equal(data, 0)
  # Count the zeros of every segment so that the three cases can be told apart.
  num_zeros = gen_math_ops.segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
  # Handle case 3: set the gradient to 0 for segments with more than one
  # 0 as input.
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # Replace all zeros with ones and compute the segment_prod; for a zero entry
  # this is the product of the remaining entries of its segment.
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
  non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
  gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
  # Divide by `non_zero_data` rather than `data`: both agree wherever
  # `is_zero` is False, which are the only positions selected below, but the
  # former never produces NaN/Inf. Non-finite values in the unselected branch
  # of a `where` can still leak out when this gradient is differentiated again.
  prod_divided_by_el = gathered_prod / non_zero_data
  # Now fetch the individual results for segments containing 0 and those that
  # don't.
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = array_ops.gather(grad, segment_ids)
  return gathered_grad * partial_derivative, None
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,

View File

@ -374,6 +374,43 @@ class SegmentMinOrMaxGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
@test_util.run_all_in_graph_and_eager_modes
class SegmentProdGradientTest(test.TestCase):
  """Checks the gradient of `math_ops.segment_prod` with the gradient checker."""

  def _run_gradient_check(self, data, segment_ids):
    """Asserts that the gradient checker's error for segment_prod is < 2e-4.

    Args:
      data: Tensor passed as the differentiated input of `segment_prod`.
      segment_ids: Segment ids passed unchanged to `segment_prod`.
    """

    def prod_over_segments(values):
      return math_ops.segment_prod(values, segment_ids)

    theoretical, numerical = gradient_checker_v2.compute_gradient(
        prod_over_segments, [data])
    jacobian_error = gradient_checker_v2.max_error(theoretical, numerical)
    self.assertLess(jacobian_error, 2e-4)

  def testSegmentProdGradientWithoutOverlap(self):
    # Every row is its own segment.
    self._run_gradient_check(
        data=constant_op.constant(
            [[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
            dtype=dtypes.float32),
        segment_ids=constant_op.constant([0, 1, 2], dtype=dtypes.int64))

  def testSegmentProdGradientWithoutZeros(self):
    # The first two rows share segment 0; no entry is zero.
    self._run_gradient_check(
        data=constant_op.constant(
            [[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
            dtype=dtypes.float32),
        segment_ids=constant_op.constant([0, 0, 1], dtype=dtypes.int64))

  def testSegmentProdGradientWithZeros(self):
    # Within segment 0 the columns hold two, one, no and one zero respectively;
    # segment 1 is a single row that itself contains zeros.
    self._run_gradient_check(
        data=constant_op.constant(
            [[0, 2, 3, 4], [0, 0, 2, 0], [5, 0, 7, 0]],
            dtype=dtypes.float32),
        segment_ids=constant_op.constant([0, 0, 1], dtype=dtypes.int64))

  def testSegmentProdGradientWithEmptySegment(self):
    # Segment id 1 is skipped, so that segment receives no rows.
    self._run_gradient_check(
        data=constant_op.constant(
            [[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
            dtype=dtypes.float32),
        segment_ids=constant_op.constant([0, 0, 2], dtype=dtypes.int64))
class FloorModGradientTest(test.TestCase):
@test_util.run_deprecated_v1