Fix a bug in the robust gradient for div. Add robust gradient for real_div.

PiperOrigin-RevId: 238719379
This commit is contained in:
A. Unique TensorFlower 2019-03-15 15:13:05 -07:00 committed by TensorFlower Gardener
parent c0886e70d1
commit 7ea3243fac
2 changed files with 53 additions and 35 deletions

View File

@ -1139,6 +1139,7 @@ class SingularGradientOpTest(test.TestCase):
(gen_math_ops.acos, (1.,)),
(gen_math_ops.atan2, (0., 0.)),
(gen_math_ops.div, (1., 0.)),
(gen_math_ops.real_div, (1., 0.)),
(math_ops.pow, (0., -1.)),
]
for op, singularity in ops_and_singularity:

View File

@ -177,8 +177,8 @@ def _ProdGrad(op, grad):
left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
# For complex inputs, the gradient is in the conjugate direction.
y = array_ops.reshape(math_ops.conj(left) * math_ops.conj(right),
permuted_shape)
y = array_ops.reshape(
math_ops.conj(left) * math_ops.conj(right), permuted_shape)
# Invert the transpose and reshape operations.
# Make sure to set the statically known shape information through a reshape.
@ -261,8 +261,8 @@ def _SegmentMinOrMaxGrad(op, grad):
# Get the number of selected (minimum or maximum) elements in each segment.
gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
op.inputs[1])
num_selected = math_ops.segment_sum(
math_ops.cast(is_selected, grad.dtype), op.inputs[1])
# Compute the gradient for each segment. The gradient for the ith segment is
# divided evenly among the selected elements in that segment.
weighted_grads = math_ops.divide(grad, num_selected)
@ -282,9 +282,13 @@ def _SegmentMaxGrad(op, grad):
return _SegmentMinOrMaxGrad(op, grad)
def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
""" Helper function for unsorted segment ops. Gathers params for
""" Helper function for unsorted segment ops.
Gathers params for
positive segment ids and gathers 0 for inputs with negative segment id.
Also returns the clipped indices and a boolean mask with the same shape
as ids where a positive id is masked as true. With this, the latter two
@ -300,8 +304,8 @@ def _GatherDropNegatives(params, ids, zero_clipped_indices=None,
# todo(philjd): remove this if tf.where supports broadcasting (#9284)
for _ in range(gathered.shape.ndims - is_positive.shape.ndims):
is_positive = array_ops.expand_dims(is_positive, -1)
is_positive = (is_positive &
array_ops.ones_like(gathered, dtype=dtypes.bool))
is_positive = (
is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
# replace gathered params of negative indices with 0
zero_slice = array_ops.zeros_like(gathered)
return (array_ops.where(is_positive, gathered, zero_slice),
@ -321,8 +325,7 @@ def _UnsortedSegmentMinOrMaxGrad(op, grad):
# divided evenly among the selected elements in that segment.
weighted_grads = math_ops.divide(grad, num_selected)
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
zero_clipped_indices,
is_positive)
zero_clipped_indices, is_positive)
zeros = array_ops.zeros_like(gathered_grads)
return array_ops.where(is_selected, gathered_grads, zeros), None, None
@ -348,6 +351,7 @@ def _UnsortedSegmentMinGrad(op, grad):
@ops.RegisterGradient("UnsortedSegmentProd")
def _UnsortedSegmentProdGrad(op, grad):
""" Gradient for UnsortedSegmentProd.
The gradient can be expressed for each segment by dividing the segment's
product by each element of the segment input tensor, but this approach can't
deal with zeros in the input.
@ -368,19 +372,18 @@ def _UnsortedSegmentProdGrad(op, grad):
math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
# handle case 3 and set the gradient to 0 for segments with more than one
# 0 as input
grad = array_ops.where(math_ops.greater(num_zeros, 1),
array_ops.zeros_like(grad), grad)
grad = array_ops.where(
math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
# replace all zeros with ones and compute the unsorted_segment_prod
non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]),
op.inputs[0])
non_zero_prod = gen_math_ops.unsorted_segment_prod(
non_zero_data, op.inputs[1], op.inputs[2])
non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
op.inputs[1], op.inputs[2])
# clip the indices for gather to be positive
zero_clipped_indices = math_ops.maximum(op.inputs[1],
array_ops.zeros_like(op.inputs[1]))
gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
gathered_non_zero_prod = array_ops.gather(non_zero_prod,
zero_clipped_indices)
gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf.
# Now fetch the individual results for segments containing 0 and those that
# don't. is_zero will also fetch results for entries with negative index
@ -714,8 +717,8 @@ def _IgammaGrad(op, grad):
partial_a = gen_math_ops.igamma_grad_a(a, x)
# Perform operations in log space before summing, because Gamma(a)
# and Gamma'(a) can grow large.
partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x)
- math_ops.lgamma(a))
partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
math_ops.lgamma(a))
return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@ -983,8 +986,7 @@ def _MulNoNanGrad(op, grad):
y = op.inputs[1]
if (isinstance(grad, ops.Tensor) and
_ShapesFullySpecifiedAndEqual(x, y, grad)):
return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(
x, grad)
return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
sx = array_ops.shape(x)
sy = array_ops.shape(y)
@ -992,8 +994,7 @@ def _MulNoNanGrad(op, grad):
return (array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry),
sy))
math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))
@ops.RegisterGradient("Div")
@ -1007,13 +1008,19 @@ def _DivGrad(op, grad):
x = math_ops.conj(x)
y = math_ops.conj(y)
if compat.forward_compatible(2019, 4, 7):
div_op = math_ops.div_no_nan
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
math_ops.mul_no_nan(
math_ops.divide(math_ops.divide(-x, y), y), grad), ry),
sy))
else:
div_op = math_ops.divide
return (array_ops.reshape(math_ops.reduce_sum(div_op(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(grad * div_op(math_ops.divide(-x, y), y), ry),
sy))
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.divide(math_ops.divide(-x, y), y), ry), sy))
@ops.RegisterGradient("FloorDiv")
@ -1053,11 +1060,21 @@ def _RealDivGrad(op, grad):
rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
if compat.forward_compatible(2019, 4, 7):
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
math_ops.mul_no_nan(
math_ops.realdiv(math_ops.realdiv(-x, y), y), grad),
ry), sy))
else:
return (array_ops.reshape(
math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry),
sy))
@ops.RegisterGradient("DivNoNan")
@ -1359,8 +1376,8 @@ def _ComplexAbsGrad(op, grad):
"""Returns the gradient of ComplexAbs."""
# TODO(b/27786104): The cast to complex could be removed once arithmetic
# supports mixtures of complex64 and real values.
return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign(
op.inputs[0]))
return (math_ops.complex(grad, array_ops.zeros_like(grad)) *
math_ops.sign(op.inputs[0]))
@ops.RegisterGradient("Cast")