Make betainc derivative more numerically stable by using log1p (and handling NaNs when a = 1 or b = 1).
PiperOrigin-RevId: 289918587
Change-Id: I266fd10c0e34ad30291d061636926c945fe0f824
Parent: 18369ac065
Commit: 86fa42f516
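The instability being fixed is an IEEE-754 artifact: when b = 1 and x = 1 (or a = 1 and x = 0), the old expression evaluates 0 * log(0) = 0 * (-inf), which is NaN, even though the mathematical limit of the term is 0. Below is a minimal NumPy sketch of the failure mode and of the xlog1py semantics the fix relies on; the scalar xlog1py helper is an illustrative stand-in for the op's documented behavior, not TensorFlow's implementation.

import numpy as np

# Old formulation: (b - 1) * log(1 - x). With b == 1 and x == 1 this is
# 0 * log(0) = 0 * (-inf), which IEEE 754 defines as NaN.
b, x = 1.0, 1.0
print((b - 1) * np.log1p(-x))  # nan

def xlog1py(z, y):
  # Documented semantics of xlog1py: return 0 when z == 0,
  # otherwise z * log1p(y). This makes the boundary term exact.
  return 0.0 if z == 0.0 else z * np.log1p(y)

print(xlog1py(b - 1, -x))  # 0.0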
@@ -944,8 +944,10 @@ def _BetaincGrad(op, grad):
   log_beta = (
       gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
       gen_math_ops.lgamma(a + b))
-  partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) +
-                           (a - 1) * math_ops.log(x) - log_beta)
+  # We use xlog1py and xlogy since the derivatives should tend to
+  # zero on one of the tails when a is 1. or b is 1.
+  partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
+                           math_ops.xlogy(a - 1, x) - log_beta)
 
   # TODO(b/36815900): Mark None return values as NotImplemented
   if compat.forward_compatible(2020, 3, 14):
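To see the effect of the patch, here is a minimal sketch that mirrors the new partial_x expression through the public tf.math API (the commit itself calls the internal math_ops/gen_math_ops wrappers, and tf.math.xlog1py assumes a TensorFlow build where that op has landed). With a = 1, the derivative is now finite and exact at both endpoints of [0, 1]:

import tensorflow as tf

def betainc_partial_x(a, b, x):
  # Mirrors the patched gradient body: stay in log space, and let
  # xlog1py/xlogy return exact zeros where their first argument is 0.
  log_beta = (tf.math.lgamma(a) + tf.math.lgamma(b) -
              tf.math.lgamma(a + b))
  return tf.math.exp(tf.math.xlog1py(b - 1., -x) +
                     tf.math.xlogy(a - 1., x) - log_beta)

a, b = tf.constant(1.), tf.constant(2.)
x = tf.constant([0., 0.5, 1.])
print(betainc_partial_x(a, b, x).numpy())  # [2. 1. 0.], no NaN at x = 0 or x = 1

Under the old formulation, the x = 0 entry would hit 0 * log(0) and come back NaN; here xlogy(0, 0) is an exact zero by definition, so only the log-beta term survives.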