[XLA] Implement the regularized incomplete beta function

An implementation involving continued fractions is used. The continued fraction
for this function can be seen at http://dlmf.nist.gov/8.17.v

PiperOrigin-RevId: 277327138
Change-Id: Icb8527af344b629806fd7e4880072d05e2530cd2
This commit is contained in:
David Majnemer 2019-10-29 11:27:54 -07:00 committed by TensorFlower Gardener
parent 2061ec8104
commit e5983061d3
9 changed files with 396 additions and 23 deletions

View File

@ -1022,6 +1022,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import scipy.special as sps
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
@ -31,7 +32,7 @@ from tensorflow.python.platform import googletest
class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _testTernary(self, op, a, b, c, expected):
def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6):
with self.session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
@ -39,7 +40,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
output = op(pa, pb, pc)
result = session.run(output, {pa: a, pb: b, pc: c})
self.assertAllClose(result, expected, rtol=1e-3)
self.assertAllClose(result, expected, rtol=rtol, atol=atol)
return result
@parameterized.parameters(
@ -210,6 +211,55 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
upper,
expected=np.minimum(np.maximum(x, lower), upper))
def testBetaincSanity(self):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:
# Sanity check a few identities:
# - betainc(a, b, 0) == 0
# - betainc(a, b, 1) == 1
# - betainc(a, 1, x) == x ** a
# Compare against the implementation in SciPy.
a = np.array([.3, .4, .2, .2], dtype=dtype)
b = np.array([1., 1., .4, .4], dtype=dtype)
x = np.array([.3, .4, .0, .1], dtype=dtype)
expected = sps.betainc(a, b, x)
self._testTernary(
math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6)
@parameterized.parameters(
{
'sigma': 1e15,
'rtol': 1e-6,
'atol': 1e-6
},
{
'sigma': 30,
'rtol': 1e-6,
'atol': 2e-3
},
{
'sigma': 1e-8,
'rtol': 5e-4,
'atol': 3e-6
},
{
'sigma': 1e-16,
'rtol': 1e-6,
'atol': 2e-4
},
)
def testBetainc(self, sigma, rtol, atol):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:
# Randomly generate a, b, x in the numerical domain of betainc.
# Compare against the implementation in SciPy.
a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty)
b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty)
x = np.random.rand(10, 10).astype(dtype) # in (0, 1)
expected = sps.betainc(a, b, x, dtype=dtype)
self._testTernary(
math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol)
if __name__ == "__main__":
googletest.main()

View File

@ -14,6 +14,7 @@ tf_kernel_library(
"batch_norm_op.cc",
"batchtospace_op.cc",
"bcast_ops.cc",
"beta_op.cc",
"bias_ops.cc",
"binary_ops.cc",
"broadcast_to_op.cc",

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
namespace {
class BetaincOp : public XlaOpKernel {
public:
explicit BetaincOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape& a_shape = ctx->InputShape(0);
const TensorShape& b_shape = ctx->InputShape(1);
const TensorShape& x_shape = ctx->InputShape(2);
if (a_shape.dims() > 0 && b_shape.dims() > 0) {
OP_REQUIRES(ctx, a_shape == b_shape,
errors::InvalidArgument(
"Shapes of a and b are inconsistent: ",
a_shape.DebugString(), " vs. ", b_shape.DebugString()));
}
if (a_shape.dims() > 0 && x_shape.dims() > 0) {
OP_REQUIRES(ctx, a_shape == x_shape,
errors::InvalidArgument(
"Shapes of a and x are inconsistent: ",
a_shape.DebugString(), " vs. ", x_shape.DebugString()));
}
if (b_shape.dims() > 0 && x_shape.dims() > 0) {
OP_REQUIRES(ctx, b_shape == x_shape,
errors::InvalidArgument(
"Shapes of b and x are inconsistent: ",
b_shape.DebugString(), " vs. ", x_shape.DebugString()));
}
TensorShape merged_shape(a_shape);
if (b_shape.dims() > 0) merged_shape = b_shape;
if (x_shape.dims() > 0) merged_shape = x_shape;
auto builder = ctx->builder();
auto result =
builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(
auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(
auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(
auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes()));
return xla::RegularizedIncompleteBeta(a, b, x);
});
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("Betainc"), BetaincOp);
} // namespace
} // namespace tensorflow

View File

@ -145,6 +145,7 @@ cc_library(
deps = [
":arithmetic",
":constants",
":loops",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:xla_builder",

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -497,6 +499,30 @@ XlaOp Lgamma(XlaOp input) {
});
}
// Computes an approximation of the lbeta function which is equivalent to
// log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
static XlaOp Lbeta(XlaOp a, XlaOp b) {
// Beta(a, b) can be computed using Gamma as per
// http://dlmf.nist.gov/5.12.E1 as follows:
// Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
//
// To avoid overflow, we compute in the log domain.
//
// As per http://dlmf.nist.gov/4.8.E2 we can transform:
// Log(a * b)
// into:
// Log(a) + Log(b)
//
// Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
// Log(a - b)
// into:
// Log(a) - Log(b)
//
// This means that we can compute Log(Beta(a, b)) by:
// Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
}
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
@ -1037,4 +1063,212 @@ XlaOp BesselI1e(XlaOp x) {
});
}
// I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
// arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
// DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
static XlaOp LentzThompsonBarnettAlgorithm(
int64 num_iterations, double small, double threshold,
const ForEachIndexBodyFunction& nth_partial_numerator,
const ForEachIndexBodyFunction& nth_partial_denominator,
absl::Span<const XlaOp> inputs, absl::string_view name) {
auto& b = *inputs.front().builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RET_CHECK(num_iterations < INT32_MAX);
enum {
// Position in the evaluation.
kIterationIdx,
// Whether or not we have reached the desired tolerance.
kValuesUnconvergedIdx,
// Ratio between nth canonical numerator and the nth-1 canonical
// numerator.
kCIdx,
// Ratio between nth-1 canonical denominator and the nth canonical
// denominator.
kDIdx,
// Computed approximant in the evaluation.
kHIdx,
// Inputs follow all of the other state.
kFirstInputIdx,
};
auto while_cond_fn = [num_iterations](
absl::Span<const XlaOp> values,
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
auto iteration = values[kIterationIdx];
auto iterations_remain_cond =
Lt(iteration, ScalarLike(iteration, num_iterations));
auto values_unconverged_cond = values[kValuesUnconvergedIdx];
return And(iterations_remain_cond, values_unconverged_cond);
};
auto while_body_fn =
[small, threshold, &nth_partial_numerator, &nth_partial_denominator](
absl::Span<const XlaOp> values,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp iteration = values[kIterationIdx];
TF_ASSIGN_OR_RETURN(
std::vector<XlaOp> partial_numerator,
nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
body_builder));
TF_RET_CHECK(partial_numerator.size() == 1);
TF_ASSIGN_OR_RETURN(
std::vector<XlaOp> partial_denominator,
nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
body_builder));
TF_RET_CHECK(partial_denominator.size() == 1);
auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
auto small_constant = FullLike(c, small);
c = Select(Lt(Abs(c), small_constant), small_constant, c);
auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
d = Select(Lt(Abs(d), small_constant), small_constant, d);
d = Reciprocal(d);
auto delta = c * d;
auto h = values[kHIdx] * delta;
std::vector<XlaOp> updated_values(values.size());
updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
updated_values[kCIdx] = c;
updated_values[kDIdx] = d;
updated_values[kHIdx] = h;
std::copy(values.begin() + kFirstInputIdx, values.end(),
updated_values.begin() + kFirstInputIdx);
// If any values are greater than the tolerance, we have not converged.
auto tolerance_comparison =
Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
updated_values[kValuesUnconvergedIdx] =
ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
CreateScalarOrComputation(PRED, body_builder));
return updated_values;
};
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
nth_partial_denominator(Zero(&b, U32), inputs, &b));
TF_RET_CHECK(partial_denominator.size() == 1);
auto h = partial_denominator[0];
auto small_constant = FullLike(h, small);
h = Select(Lt(Abs(h), small_constant), small_constant, h);
std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
values[kIterationIdx] = One(&b, U32);
values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
values[kCIdx] = h;
values[kDIdx] = FullLike(h, 0.0);
values[kHIdx] = h;
std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
values, name, &b));
return values[kHIdx];
});
}
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
auto& builder = *x.builder();
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
// The partial numerator for the incomplete beta function is given
// here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
// case: the partial numerator for the first iteration is one.
auto NthPartialBetaincNumerator =
[&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto a = inputs[0];
auto b = inputs[1];
auto x = inputs[2];
auto iteration_bcast = Broadcast(iteration, shape.dimensions());
auto iteration_is_even =
Eq(iteration_bcast % FullLike(iteration_bcast, 2),
FullLike(iteration_bcast, 0));
auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
m = ConvertElementType(m, shape.element_type());
auto one = FullLike(a, 1.0);
auto two = FullLike(a, 2.0);
// Partial numerator terms.
auto even_numerator =
-(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
auto odd_numerator =
m * (b - m) * x / ((a + two * m - one) * (a + two * m));
auto one_numerator = ScalarLike(x, 1.0);
auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
return std::vector<XlaOp>{
Select(iteration_is_one, one_numerator, numerator)};
};
auto NthPartialBetaincDenominator =
[&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto x = inputs[2];
auto iteration_bcast = Broadcast(iteration, shape.dimensions());
return std::vector<XlaOp>{
Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
};
// Determine if the inputs are out of range.
auto result_is_nan =
Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
Lt(x, ScalarLike(x, 0.0))),
Gt(x, ScalarLike(x, 1.0)));
// The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
// as per: http://dlmf.nist.gov/8.17.E23
//
// Otherwise, we can rewrite using the symmetry relation as per:
// http://dlmf.nist.gov/8.17.E4
auto converges_rapidly =
Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
auto a_orig = a;
a = Select(converges_rapidly, a, b);
b = Select(converges_rapidly, b, a_orig);
x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));
XlaOp continued_fraction;
// Thresholds and iteration counts taken from Cephes.
if (shape.element_type() == F32) {
continued_fraction = LentzThompsonBarnettAlgorithm(
/*num_iterations=*/200,
/*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
/*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
/*nth_partial_numerator=*/NthPartialBetaincNumerator,
/*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
"Betainc");
} else {
TF_RET_CHECK(shape.element_type() == F64);
continued_fraction = LentzThompsonBarnettAlgorithm(
/*num_iterations=*/600,
/*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
/*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
/*nth_partial_numerator=*/NthPartialBetaincNumerator,
/*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
"Betainc");
}
// We want to compute the regularized complete beta function so we need to
// combine the continued fraction with a few more terms as well as dividing
// it by Beta(a, b). To avoid overflow, we compute in the log domain.
// See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
// formula.
auto lbeta = Lbeta(a, b);
auto result =
continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
result =
Select(result_is_nan, NanValue(&builder, shape.element_type()), result);
// We have an additional fixup to do if we are taking advantage of the
// symmetry relation.
return Select(converges_rapidly, result,
Sub(FullLike(result, 1.0), result));
});
}
} // namespace xla

View File

@ -108,6 +108,9 @@ XlaOp BesselI0e(XlaOp x);
// at x.
XlaOp BesselI1e(XlaOp x);
// Computes the Regularized Incomplete Beta function.
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_

View File

@ -50,49 +50,51 @@ class BetaincTest(test.TestCase):
tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
with self.cached_session():
tf_out = self.evaluate(tf_out_t)
scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt)
# the scipy version of betainc uses a double-only implementation.
# TODO(ebrevdo): identify reasons for (sometime) precision loss
# with doubles
tol = 1e-4 if dtype == dtypes.float32 else 5e-5
self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=0)
rtol = 1e-4 if dtype == dtypes.float32 else 5e-5
atol = 9e-6 if dtype == dtypes.float32 else 3e-6
self.assertAllCloseAccordingToType(
scipy_out, tf_out, rtol=rtol, atol=atol)
# Test out-of-range values (most should return nan output)
combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
with self.cached_session():
tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
# Test broadcasting between scalars and other shapes
with self.cached_session():
self.assertAllCloseAccordingToType(
special.betainc(0.1, b_s, x_s).astype(np_dt),
special.betainc(0.1, b_s, x_s, dtype=np_dt),
math_ops.betainc(0.1, b_s, x_s).eval(),
rtol=tol,
atol=0)
rtol=rtol,
atol=atol)
self.assertAllCloseAccordingToType(
special.betainc(a_s, 0.1, x_s).astype(np_dt),
special.betainc(a_s, 0.1, x_s, dtype=np_dt),
math_ops.betainc(a_s, 0.1, x_s).eval(),
rtol=tol,
atol=0)
rtol=rtol,
atol=atol)
self.assertAllCloseAccordingToType(
special.betainc(a_s, b_s, 0.1).astype(np_dt),
special.betainc(a_s, b_s, 0.1, dtype=np_dt),
math_ops.betainc(a_s, b_s, 0.1).eval(),
rtol=tol,
atol=0)
rtol=rtol,
atol=atol)
self.assertAllCloseAccordingToType(
special.betainc(0.1, b_s, 0.1).astype(np_dt),
special.betainc(0.1, b_s, 0.1, dtype=np_dt),
math_ops.betainc(0.1, b_s, 0.1).eval(),
rtol=tol,
atol=0)
rtol=rtol,
atol=atol)
self.assertAllCloseAccordingToType(
special.betainc(0.1, 0.1, 0.1).astype(np_dt),
special.betainc(0.1, 0.1, 0.1, dtype=np_dt),
math_ops.betainc(0.1, 0.1, 0.1).eval(),
rtol=tol,
atol=0)
rtol=rtol,
atol=atol)
with self.assertRaisesRegexp(ValueError, "must be equal"):
math_ops.betainc(0.5, [0.5], [[0.5]])

View File

@ -322,7 +322,7 @@ class BetaTest(test.TestCase):
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
if not stats:
return
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)
def testBetaLogCdf(self):
shape = (30, 40, 50)
@ -335,7 +335,7 @@ class BetaTest(test.TestCase):
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
if not stats:
return
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)
def testBetaWithSoftplusConcentration(self):
a, b = -4.2, -9.1