Fix a bug in the robust gradient for div. Add robust gradient for real_div.
PiperOrigin-RevId: 238719379
commit 7ea3243fac
parent c0886e70d1
@@ -1139,6 +1139,7 @@ class SingularGradientOpTest(test.TestCase):
         (gen_math_ops.acos, (1.,)),
         (gen_math_ops.atan2, (0., 0.)),
         (gen_math_ops.div, (1., 0.)),
+        (gen_math_ops.real_div, (1., 0.)),
         (math_ops.pow, (0., -1.)),
     ]
     for op, singularity in ops_and_singularity:
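The list above drives a test that evaluates each op's gradient at the listed singular point and expects an exact zero when the upstream gradient is zero, rather than nan from 0 * inf. A minimal standalone sketch of that behavior, assuming the TF 1.x graph-mode API this commit targets (the test's real harness differs):

import tensorflow as tf

x = tf.constant(1.)
y = tf.constant(0.)
z = tf.realdiv(x, y)  # z is inf at y == 0
# A zero upstream gradient should stay exactly zero on the robust path.
dx, dy = tf.gradients(z, [x, y], grad_ys=tf.constant(0.))
with tf.Session() as sess:
  print(sess.run([dx, dy]))  # expected [0.0, 0.0], not [nan, nan]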
@@ -177,8 +177,8 @@ def _ProdGrad(op, grad):
   left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
   right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
   # For complex inputs, the gradient is in the conjugate direction.
-  y = array_ops.reshape(math_ops.conj(left) * math_ops.conj(right),
-                        permuted_shape)
+  y = array_ops.reshape(
+      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

   # Invert the transpose and reshape operations.
   # Make sure to set the statically known shape information through a reshape.
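For context, _ProdGrad forms the product of all elements except the i-th as left[i] * right[i], the exclusive cumulative products from either side, which avoids dividing by x[i] and therefore survives zeros; the conj calls keep the gradient in the conjugate direction for complex dtypes, per TF's convention. A NumPy sketch of the identity:

import numpy as np

x = np.array([2., 3., 4.])
left = np.concatenate([[1.], np.cumprod(x[:-1])])              # [ 1.  2.  6.]
right = np.concatenate([np.cumprod(x[::-1])[::-1][1:], [1.]])  # [12.  4.  1.]
# left[i] * right[i] is the product of every element except x[i],
# i.e. d(prod(x))/dx_i, computed without dividing by x_i.
print(left * right)  # [12.  8.  6.]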
@@ -261,8 +261,8 @@ def _SegmentMinOrMaxGrad(op, grad):
   # Get the number of selected (minimum or maximum) elements in each segment.
   gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
   is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
-  num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
-                                      op.inputs[1])
+  num_selected = math_ops.segment_sum(
+      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
   # Compute the gradient for each segment. The gradient for the ith segment is
   # divided evenly among the selected elements in that segment.
   weighted_grads = math_ops.divide(grad, num_selected)
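The num_selected division above encodes the usual subgradient convention for min/max: when several elements tie for a segment's extremum, the segment's incoming gradient is split evenly among them. A worked NumPy example (segment 0 has two tied maxima):

import numpy as np

data = np.array([3., 5., 5., 2.])
segment_ids = np.array([0, 0, 0, 1])
seg_max = np.array([5., 2.])                 # segment_max(data, segment_ids)
is_selected = data == seg_max[segment_ids]   # [False, True, True, True]
num_selected = np.array([2., 1.])            # ties per segment
grad = np.array([1., 1.])                    # upstream gradient per segment
print(np.where(is_selected, (grad / num_selected)[segment_ids], 0.))
# [0.  0.5 0.5 1. ]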
@@ -282,9 +282,13 @@ def _SegmentMaxGrad(op, grad):
   return _SegmentMinOrMaxGrad(op, grad)


-def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
+def _GatherDropNegatives(params,
+                         ids,
+                         zero_clipped_indices=None,
                          is_positive=None):
-  """ Helper function for unsorted segment ops. Gathers params for
+  """ Helper function for unsorted segment ops.
+
+  Gathers params for
   positive segment ids and gathers 0 for inputs with negative segment id.
   Also returns the clipped indices and a boolean mask with the same shape
   as ids where a positive id is masked as true. With this, the latter two
@@ -300,8 +304,8 @@ def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
     # todo(philjd): remove this if tf.where supports broadcasting (#9284)
     for _ in range(gathered.shape.ndims - is_positive.shape.ndims):
       is_positive = array_ops.expand_dims(is_positive, -1)
-    is_positive = (is_positive &
-                   array_ops.ones_like(gathered, dtype=dtypes.bool))
+    is_positive = (
+        is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
   # replace gathered params of negative indices with 0
   zero_slice = array_ops.zeros_like(gathered)
   return (array_ops.where(is_positive, gathered, zero_slice),
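Unsorted segment ops treat negative segment ids as "drop this element", so the gradient must gather zeros for those entries instead of indexing with a negative id. A NumPy sketch of the clip-then-mask trick, assuming 1-D params for simplicity:

import numpy as np

params = np.array([10., 20., 30.])
ids = np.array([2, -1, 0])                  # -1 means "dropped"
zero_clipped = np.maximum(ids, 0)           # [2, 0, 0]: safe to gather
gathered = params[zero_clipped]             # [30., 10., 10.]
is_positive = ids >= 0                      # [True, False, True]
print(np.where(is_positive, gathered, 0.))  # [30.  0. 10.]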
@@ -321,8 +325,7 @@ def _UnsortedSegmentMinOrMaxGrad(op, grad):
   # divided evenly among the selected elements in that segment.
   weighted_grads = math_ops.divide(grad, num_selected)
   gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
-                                              zero_clipped_indices,
-                                              is_positive)
+                                              zero_clipped_indices, is_positive)
   zeros = array_ops.zeros_like(gathered_grads)
   return array_ops.where(is_selected, gathered_grads, zeros), None, None

@@ -348,6 +351,7 @@ def _UnsortedSegmentMinGrad(op, grad):
 @ops.RegisterGradient("UnsortedSegmentProd")
 def _UnsortedSegmentProdGrad(op, grad):
   """ Gradient for UnsortedSegmentProd.
+
   The gradient can be expressed for each segment by dividing the segment's
   product by each element of the segment input tensor, but this approach can't
   deal with zeros in the input.
@@ -368,19 +372,18 @@ def _UnsortedSegmentProdGrad(op, grad):
       math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
   # handle case 3 and set the gradient to 0 for segments with more than one
   # 0 as input
-  grad = array_ops.where(math_ops.greater(num_zeros, 1),
-                         array_ops.zeros_like(grad), grad)
+  grad = array_ops.where(
+      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
   # replace all zeros with ones and compute the unsorted_segment_prod
   non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]),
                                   op.inputs[0])
-  non_zero_prod = gen_math_ops.unsorted_segment_prod(
-      non_zero_data, op.inputs[1], op.inputs[2])
+  non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
+                                                     op.inputs[1], op.inputs[2])
   # clip the indices for gather to be positive
   zero_clipped_indices = math_ops.maximum(op.inputs[1],
                                           array_ops.zeros_like(op.inputs[1]))
   gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
-  gathered_non_zero_prod = array_ops.gather(non_zero_prod,
-                                            zero_clipped_indices)
+  gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
   prod_divided_by_el = gathered_prod / op.inputs[0]  # May contain nan/inf.
   # Now fetch the individual results for segments containing 0 and those that
   # don't. is_zero will also fetch results for entries with negative index
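Putting the docstring's cases together for one segment: with no zeros, the gradient of the product w.r.t. x[i] is prod(x) / x[i]; with exactly one zero, only the zero entry gets a gradient, equal to the product of the remaining elements; with two or more zeros, every gradient vanishes. A per-segment NumPy sketch of those three cases:

import numpy as np

def segment_prod_grad(x, grad=1.0):
  """Gradient of prod(x) w.r.t. x for one segment, handling zeros."""
  is_zero = x == 0
  num_zeros = is_zero.sum()
  if num_zeros > 1:             # case 3: gradient vanishes entirely
    return np.zeros_like(x)
  non_zero_prod = np.prod(np.where(is_zero, 1., x))
  if num_zeros == 1:            # case 2: only the zero entry gets a gradient
    return np.where(is_zero, grad * non_zero_prod, 0.)
  return grad * np.prod(x) / x  # case 1: no zeros, safe to divide

print(segment_prod_grad(np.array([2., 3., 4.])))  # [12.  8.  6.]
print(segment_prod_grad(np.array([2., 0., 3.])))  # [0.  6.  0.]
print(segment_prod_grad(np.array([0., 0., 3.])))  # [0.  0.  0.]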
@@ -714,8 +717,8 @@ def _IgammaGrad(op, grad):
   partial_a = gen_math_ops.igamma_grad_a(a, x)
   # Perform operations in log space before summing, because Gamma(a)
   # and Gamma'(a) can grow large.
-  partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x)
-                           - math_ops.lgamma(a))
+  partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
+                           math_ops.lgamma(a))
   return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
           array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))

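The log-space rewrite matters because partial_x is the Gamma density x^(a-1) * exp(-x) / Gamma(a), whose numerator and denominator overflow individually long before their ratio does. A standard-library illustration:

import math

a, x = 200.0, 180.0
# Naive: x**(a - 1) / Gamma(a) * exp(-x) overflows, since x**(a - 1) and
# Gamma(a) each exceed the float64 range on their own.
partial_x = math.exp(-x + (a - 1.) * math.log(x) - math.lgamma(a))
print(partial_x)  # finite and well scaled; 180.0**199.0 raises OverflowError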
@@ -983,8 +986,7 @@ def _MulNoNanGrad(op, grad):
   y = op.inputs[1]
   if (isinstance(grad, ops.Tensor) and
       _ShapesFullySpecifiedAndEqual(x, y, grad)):
-    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(
-        x, grad)
+    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
   assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
@@ -992,8 +994,7 @@ def _MulNoNanGrad(op, grad):
   return (array_ops.reshape(
       math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
           array_ops.reshape(
-              math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry),
-              sy))
+              math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))


 @ops.RegisterGradient("Div")
@@ -1007,13 +1008,19 @@ def _DivGrad(op, grad):
   x = math_ops.conj(x)
   y = math_ops.conj(y)
   if compat.forward_compatible(2019, 4, 7):
-    div_op = math_ops.div_no_nan
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    math_ops.mul_no_nan(
+                        math_ops.divide(math_ops.divide(-x, y), y), grad), ry),
+                sy))
   else:
-    div_op = math_ops.divide
-  return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
-              sy))
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy))


 @ops.RegisterGradient("FloorDiv")
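This hunk is the bug fix named in the commit message. The old forward-compatible branch computed the y-gradient as grad * div_no_nan(-x / y, y), which is zero whenever y == 0, so it silently zeroed even genuinely divergent gradients. The rewrite uses mul_no_nan(quotient, grad), which is zero exactly when the upstream grad is zero: the 0 * inf == nan case the robust gradient is meant to suppress, while a real singularity still surfaces. A NumPy sketch with where-based stand-ins for the two no-nan ops:

import numpy as np

def div_no_nan(a, b):  # semantics of math_ops.div_no_nan
  return np.where(b == 0., 0., a / b)

def mul_no_nan(a, b):  # semantics of math_ops.mul_no_nan
  return np.where(b == 0., 0., a * b)

x, y = np.float64(1.), np.float64(0.)
with np.errstate(divide="ignore", invalid="ignore"):
  for grad in (np.float64(0.), np.float64(1.)):
    dy_old = grad * div_no_nan(-x / y, y)  # 0 even for grad == 1: the bug
    dy_new = mul_no_nan(-x / y / y, grad)  # 0 only for grad == 0
    print(grad, dy_old, dy_new)
# 0.0 0.0 0.0
# 1.0 0.0 -inf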
@@ -1053,11 +1060,21 @@ def _RealDivGrad(op, grad):
   rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(
-      math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
-          array_ops.reshape(
-              math_ops.reduce_sum(
-                  grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
+  if compat.forward_compatible(2019, 4, 7):
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    math_ops.mul_no_nan(
+                        math_ops.realdiv(math_ops.realdiv(-x, y), y), grad),
+                    ry), sy))
+  else:
+    return (array_ops.reshape(
+        math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
+            array_ops.reshape(
+                math_ops.reduce_sum(
+                    grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry),
+                sy))


 @ops.RegisterGradient("DivNoNan")
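Both branches stay in the code because the choice is made at graph-construction time against a forward-compatibility horizon: until 2019-04-07 the frontend keeps emitting the legacy realdiv gradient, so freshly generated GraphDefs still run on slightly older binaries that may lack the newer no-nan kernels. A minimal illustration (compat is the same internal module this file imports):

from tensorflow.python.compat import compat

# The branch is chosen when the gradient graph is built, not when it runs.
if compat.forward_compatible(2019, 4, 7):
  print("emit div_no_nan/mul_no_nan gradient ops")
else:
  print("emit the legacy realdiv gradient ops")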
@@ -1359,8 +1376,8 @@ def _ComplexAbsGrad(op, grad):
   """Returns the gradient of ComplexAbs."""
   # TODO(b/27786104): The cast to complex could be removed once arithmetic
   # supports mixtures of complex64 and real values.
-  return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign(
-      op.inputs[0]))
+  return (math_ops.complex(grad, array_ops.zeros_like(grad)) *
+          math_ops.sign(op.inputs[0]))


 @ops.RegisterGradient("Cast")
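For reference, math_ops.sign of a complex tensor is z / |z|, so this gradient scales the real-valued upstream gradient by the unit direction of the input; the explicit complex(grad, 0) cast works around the mixed complex/real arithmetic noted in the TODO. A NumPy sketch:

import numpy as np

z = np.complex64(3 + 4j)
grad = np.float32(1.)             # |z| is real, so its upstream grad is real
sign = z / np.abs(z)              # (0.6+0.8j), what math_ops.sign returns
# Cast the real grad to complex before multiplying, as the gradient does:
print(np.complex64(grad) * sign)  # (0.6+0.8j)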