Merge pull request #44298 from shauidu:Fix_CumprodGrad
PiperOrigin-RevId: 339212445 Change-Id: I7b5e44ea2b1fdf50e021722ab9ed93d049e144de
This commit is contained in:
commit
4d6ea6bd4d
@ -1940,11 +1940,10 @@ def _CumprodGrad(op, grad):
|
||||
exclusive = op.get_attr("exclusive")
|
||||
reverse = op.get_attr("reverse")
|
||||
|
||||
# TODO This fails when x contains 0 and should be fixed
|
||||
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
|
||||
out = math_ops.cumsum(
|
||||
prod * grad, axis, exclusive=exclusive, reverse=not reverse)
|
||||
return [out / x, None]
|
||||
return [math_ops.div_no_nan(out, x), None]
|
||||
|
||||
|
||||
@ops.RegisterGradient("CumulativeLogsumexp")
|
||||
|
Loading…
x
Reference in New Issue
Block a user