Merge pull request #44298 from shauidu:Fix_CumprodGrad
PiperOrigin-RevId: 339212445 Change-Id: I7b5e44ea2b1fdf50e021722ab9ed93d049e144de
This commit is contained in:
commit
4d6ea6bd4d
@ -1940,11 +1940,10 @@ def _CumprodGrad(op, grad):
|
|||||||
exclusive = op.get_attr("exclusive")
|
exclusive = op.get_attr("exclusive")
|
||||||
reverse = op.get_attr("reverse")
|
reverse = op.get_attr("reverse")
|
||||||
|
|
||||||
# TODO This fails when x contains 0 and should be fixed
|
|
||||||
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
|
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
|
||||||
out = math_ops.cumsum(
|
out = math_ops.cumsum(
|
||||||
prod * grad, axis, exclusive=exclusive, reverse=not reverse)
|
prod * grad, axis, exclusive=exclusive, reverse=not reverse)
|
||||||
return [out / x, None]
|
return [math_ops.div_no_nan(out, x), None]
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("CumulativeLogsumexp")
|
@ops.RegisterGradient("CumulativeLogsumexp")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user