[XLA] Implement the regularized incomplete beta function
An implementation involving continued fractions is used. The continued
fraction for this function can be seen at http://dlmf.nist.gov/8.17.v

PiperOrigin-RevId: 277327138
Change-Id: Icb8527af344b629806fd7e4880072d05e2530cd2
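For reference, the fraction being evaluated is DLMF 8.17.22, with the partial numerators from 8.17.23 and 8.17.24 (transcribed here for convenience; this derivation note is mine, not part of the commit message):

```latex
I_x(a,b) \;=\; \frac{x^a (1-x)^b}{a \, B(a,b)}
\;\cfrac{1}{1 + \cfrac{d_1}{1 + \cfrac{d_2}{1 + \cdots}}},
\qquad
d_{2m} = \frac{m (b - m)\, x}{(a + 2m - 1)(a + 2m)},
\qquad
d_{2m+1} = -\,\frac{(a + m)(a + b + m)\, x}{(a + 2m)(a + 2m + 1)}
```

These are exactly the `odd_numerator` and `even_numerator` terms in the XLA implementation below.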
parent 2061ec8104
commit e5983061d3
@@ -1022,6 +1022,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
@@ -20,6 +20,7 @@ from __future__ import print_function

 from absl.testing import parameterized
 import numpy as np
+import scipy.special as sps

 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
@@ -31,7 +32,7 @@ from tensorflow.python.platform import googletest

 class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):

-  def _testTernary(self, op, a, b, c, expected):
+  def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6):
     with self.session() as session:
       with self.test_scope():
         pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
@@ -39,7 +40,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
         pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
         output = op(pa, pb, pc)
       result = session.run(output, {pa: a, pb: b, pc: c})
-      self.assertAllClose(result, expected, rtol=1e-3)
+      self.assertAllClose(result, expected, rtol=rtol, atol=atol)
       return result

   @parameterized.parameters(
@@ -210,6 +211,55 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
         upper,
         expected=np.minimum(np.maximum(x, lower), upper))

+  def testBetaincSanity(self):
+    # This operation is only supported for float32 and float64.
+    for dtype in self.numeric_types & {np.float32, np.float64}:
+      # Sanity check a few identities:
+      # - betainc(a, b, 0) == 0
+      # - betainc(a, b, 1) == 1
+      # - betainc(a, 1, x) == x ** a
+      # Compare against the implementation in SciPy.
+      a = np.array([.3, .4, .2, .2], dtype=dtype)
+      b = np.array([1., 1., .4, .4], dtype=dtype)
+      x = np.array([.3, .4, .0, .1], dtype=dtype)
+      expected = sps.betainc(a, b, x)
+      self._testTernary(
+          math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6)
+
+  @parameterized.parameters(
+      {
+          'sigma': 1e15,
+          'rtol': 1e-6,
+          'atol': 1e-6
+      },
+      {
+          'sigma': 30,
+          'rtol': 1e-6,
+          'atol': 2e-3
+      },
+      {
+          'sigma': 1e-8,
+          'rtol': 5e-4,
+          'atol': 3e-6
+      },
+      {
+          'sigma': 1e-16,
+          'rtol': 1e-6,
+          'atol': 2e-4
+      },
+  )
+  def testBetainc(self, sigma, rtol, atol):
+    # This operation is only supported for float32 and float64.
+    for dtype in self.numeric_types & {np.float32, np.float64}:
+      # Randomly generate a, b, x in the numerical domain of betainc.
+      # Compare against the implementation in SciPy.
+      a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype)  # in (0, infty)
+      b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype)  # in (0, infty)
+      x = np.random.rand(10, 10).astype(dtype)  # in (0, 1)
+      expected = sps.betainc(a, b, x, dtype=dtype)
+      self._testTernary(
+          math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol)
+

 if __name__ == "__main__":
   googletest.main()
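The identities listed in `testBetaincSanity` are cheap to verify against SciPy outside the test harness as well; a standalone sketch (my own illustration, not part of the diff):

```python
import numpy as np
import scipy.special as sps

a = np.array([0.3, 0.4, 0.2, 0.2])
x = np.linspace(0.0, 1.0, 5)

# betainc(a, b, 0) == 0 and betainc(a, b, 1) == 1 for any valid a, b.
assert np.allclose(sps.betainc(a, 2.0, 0.0), 0.0)
assert np.allclose(sps.betainc(a, 2.0, 1.0), 1.0)

# betainc(a, 1, x) == x ** a: with b == 1 the normalized integrand reduces
# to a * t**(a - 1), whose integral over [0, x] is x**a.
for ai in a:
    assert np.allclose(sps.betainc(ai, 1.0, x), x ** ai)
```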
@@ -14,6 +14,7 @@ tf_kernel_library(
         "batch_norm_op.cc",
         "batchtospace_op.cc",
         "bcast_ops.cc",
+        "beta_op.cc",
         "bias_ops.cc",
         "binary_ops.cc",
         "broadcast_to_op.cc",
|
81
tensorflow/compiler/tf2xla/kernels/beta_op.cc
Normal file
81
tensorflow/compiler/tf2xla/kernels/beta_op.cc
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class BetaincOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit BetaincOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape& a_shape = ctx->InputShape(0);
|
||||||
|
const TensorShape& b_shape = ctx->InputShape(1);
|
||||||
|
const TensorShape& x_shape = ctx->InputShape(2);
|
||||||
|
if (a_shape.dims() > 0 && b_shape.dims() > 0) {
|
||||||
|
OP_REQUIRES(ctx, a_shape == b_shape,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Shapes of a and b are inconsistent: ",
|
||||||
|
a_shape.DebugString(), " vs. ", b_shape.DebugString()));
|
||||||
|
}
|
||||||
|
if (a_shape.dims() > 0 && x_shape.dims() > 0) {
|
||||||
|
OP_REQUIRES(ctx, a_shape == x_shape,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Shapes of a and x are inconsistent: ",
|
||||||
|
a_shape.DebugString(), " vs. ", x_shape.DebugString()));
|
||||||
|
}
|
||||||
|
if (b_shape.dims() > 0 && x_shape.dims() > 0) {
|
||||||
|
OP_REQUIRES(ctx, b_shape == x_shape,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Shapes of b and x are inconsistent: ",
|
||||||
|
b_shape.DebugString(), " vs. ", x_shape.DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape merged_shape(a_shape);
|
||||||
|
if (b_shape.dims() > 0) merged_shape = b_shape;
|
||||||
|
if (x_shape.dims() > 0) merged_shape = x_shape;
|
||||||
|
|
||||||
|
auto builder = ctx->builder();
|
||||||
|
auto result =
|
||||||
|
builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes()));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes()));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes()));
|
||||||
|
return xla::RegularizedIncompleteBeta(a, b, x);
|
||||||
|
});
|
||||||
|
ctx->SetOutput(0, result);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("Betainc"), BetaincOp);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
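The kernel broadcasts scalar operands to the single non-scalar shape (`merged_shape`) but rejects mismatched non-scalar shapes. A rough illustration of that contract through the public op (hypothetical usage, not taken from the diff; assumes TF 2.x eager execution):

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(3, 4).astype(np.float32)

# Scalars broadcast against the one non-scalar shape:
y = tf.math.betainc(0.5, 2.0, x)        # result has shape (3, 4)

# Non-scalar operands must agree exactly; mixed ranks are rejected:
# tf.math.betainc(0.5, [0.5], [[0.5]])  # raises: shapes "must be equal"
```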
@@ -145,6 +145,7 @@ cc_library(
     deps = [
         ":arithmetic",
        ":constants",
+        ":loops",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/client:xla_builder",
@@ -21,6 +21,8 @@ limitations under the License.

 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/loops.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -497,6 +499,30 @@ XlaOp Lgamma(XlaOp input) {
   });
 }

+// Computes an approximation of the lbeta function which is equivalent to
+// log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
+static XlaOp Lbeta(XlaOp a, XlaOp b) {
+  // Beta(a, b) can be computed using Gamma as per
+  // http://dlmf.nist.gov/5.12.E1 as follows:
+  //   Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
+  //
+  // To avoid overflow, we compute in the log domain.
+  //
+  // As per http://dlmf.nist.gov/4.8.E2 we can transform:
+  //   Log(a * b)
+  // into:
+  //   Log(a) + Log(b)
+  //
+  // Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
+  //   Log(a / b)
+  // into:
+  //   Log(a) - Log(b)
+  //
+  // This means that we can compute Log(Beta(a, b)) by:
+  //   Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
+  return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
+}
+
 // Compute the Digamma function using Lanczos' approximation from "A Precision
 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 // series B. Vol. 1:
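A quick NumPy/SciPy sketch of why the log-domain identity matters (illustrative only; `gammaln` plays the role of `Lgamma`):

```python
import numpy as np
from scipy.special import beta, gammaln

def lbeta(a, b):
    # log(Beta(a, b)) = lgamma(a) + lgamma(b) - lgamma(a + b)
    return gammaln(a) + gammaln(b) - gammaln(a + b)

a = b = 1000.0
print(beta(a, b))    # 0.0: the direct Gamma-product form underflows in float64
print(lbeta(a, b))   # about -1388.5, comfortably representable
```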
@@ -1037,4 +1063,212 @@ XlaOp BesselI1e(XlaOp x) {
   });
 }

+// I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
+// arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
+// DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
+static XlaOp LentzThompsonBarnettAlgorithm(
+    int64 num_iterations, double small, double threshold,
+    const ForEachIndexBodyFunction& nth_partial_numerator,
+    const ForEachIndexBodyFunction& nth_partial_denominator,
+    absl::Span<const XlaOp> inputs, absl::string_view name) {
+  auto& b = *inputs.front().builder();
+  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_RET_CHECK(num_iterations < INT32_MAX);
+
+    enum {
+      // Position in the evaluation.
+      kIterationIdx,
+      // Whether or not we have reached the desired tolerance.
+      kValuesUnconvergedIdx,
+      // Ratio between nth canonical numerator and the nth-1 canonical
+      // numerator.
+      kCIdx,
+      // Ratio between nth-1 canonical denominator and the nth canonical
+      // denominator.
+      kDIdx,
+      // Computed approximant in the evaluation.
+      kHIdx,
+      // Inputs follow all of the other state.
+      kFirstInputIdx,
+    };
+    auto while_cond_fn = [num_iterations](
+                             absl::Span<const XlaOp> values,
+                             XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
+      auto iteration = values[kIterationIdx];
+      auto iterations_remain_cond =
+          Lt(iteration, ScalarLike(iteration, num_iterations));
+      auto values_unconverged_cond = values[kValuesUnconvergedIdx];
+      return And(iterations_remain_cond, values_unconverged_cond);
+    };
+
+    auto while_body_fn =
+        [small, threshold, &nth_partial_numerator, &nth_partial_denominator](
+            absl::Span<const XlaOp> values,
+            XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
+      XlaOp iteration = values[kIterationIdx];
+
+      TF_ASSIGN_OR_RETURN(
+          std::vector<XlaOp> partial_numerator,
+          nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
+                                body_builder));
+      TF_RET_CHECK(partial_numerator.size() == 1);
+
+      TF_ASSIGN_OR_RETURN(
+          std::vector<XlaOp> partial_denominator,
+          nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
+                                  body_builder));
+      TF_RET_CHECK(partial_denominator.size() == 1);
+
+      auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
+      auto small_constant = FullLike(c, small);
+      c = Select(Lt(Abs(c), small_constant), small_constant, c);
+
+      auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
+      d = Select(Lt(Abs(d), small_constant), small_constant, d);
+
+      d = Reciprocal(d);
+
+      auto delta = c * d;
+      auto h = values[kHIdx] * delta;
+
+      std::vector<XlaOp> updated_values(values.size());
+      updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
+      updated_values[kCIdx] = c;
+      updated_values[kDIdx] = d;
+      updated_values[kHIdx] = h;
+      std::copy(values.begin() + kFirstInputIdx, values.end(),
+                updated_values.begin() + kFirstInputIdx);
+
+      // If any values are greater than the tolerance, we have not converged.
+      auto tolerance_comparison =
+          Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
+      updated_values[kValuesUnconvergedIdx] =
+          ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
+                    CreateScalarOrComputation(PRED, body_builder));
+      return updated_values;
+    };
+
+    TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
+                        nth_partial_denominator(Zero(&b, U32), inputs, &b));
+    TF_RET_CHECK(partial_denominator.size() == 1);
+    auto h = partial_denominator[0];
+    auto small_constant = FullLike(h, small);
+    h = Select(Lt(Abs(h), small_constant), small_constant, h);
+
+    std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
+    values[kIterationIdx] = One(&b, U32);
+    values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
+    values[kCIdx] = h;
+    values[kDIdx] = FullLike(h, 0.0);
+    values[kHIdx] = h;
+    std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
+    TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
+                                                values, name, &b));
+    return values[kHIdx];
+  });
+}
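For intuition, here is a scalar Python rendering of the same modified Lentz iteration: a sketch of the scheme, not the XLA loop above, with `numerator(n)` and `denominator(n)` standing in for the partial-term callbacks:

```python
def lentz(num_iterations, small, threshold, numerator, denominator):
    """Evaluates the continued fraction b0 + a1/(b1 + a2/(b2 + ...)).

    numerator(n) -> a_n and denominator(n) -> b_n; `small` replaces
    near-zero intermediates, `threshold` is the convergence tolerance.
    """
    h = denominator(0)
    h = h if abs(h) >= small else small  # clamp b0 away from zero
    c, d = h, 0.0
    for n in range(1, num_iterations):
        a_n, b_n = numerator(n), denominator(n)
        c = b_n + a_n / c
        c = c if abs(c) >= small else small
        d = b_n + a_n * d
        d = 1.0 / (d if abs(d) >= small else small)
        delta = c * d
        h *= delta
        if abs(delta - 1.0) < threshold:  # converged
            break
    return h

# Sanity check: the golden ratio as 1 + 1/(1 + 1/(1 + ...)).
print(lentz(100, 1e-30, 1e-15, lambda n: 1.0, lambda n: 1.0))  # ~1.6180339887
```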
+
+XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
+  auto& builder = *x.builder();
+  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
+
+    // The partial numerator for the incomplete beta function is given
+    // here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
+    // case: the partial numerator for the first iteration is one.
+    auto NthPartialBetaincNumerator =
+        [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
+                 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+      auto a = inputs[0];
+      auto b = inputs[1];
+      auto x = inputs[2];
+      auto iteration_bcast = Broadcast(iteration, shape.dimensions());
+      auto iteration_is_even =
+          Eq(iteration_bcast % FullLike(iteration_bcast, 2),
+             FullLike(iteration_bcast, 0));
+      auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
+      auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
+      auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
+      m = ConvertElementType(m, shape.element_type());
+      auto one = FullLike(a, 1.0);
+      auto two = FullLike(a, 2.0);
+      // Partial numerator terms.
+      auto even_numerator =
+          -(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
+      auto odd_numerator =
+          m * (b - m) * x / ((a + two * m - one) * (a + two * m));
+      auto one_numerator = ScalarLike(x, 1.0);
+      auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
+      return std::vector<XlaOp>{
+          Select(iteration_is_one, one_numerator, numerator)};
+    };
+
+    auto NthPartialBetaincDenominator =
+        [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
+                 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+      auto x = inputs[2];
+      auto iteration_bcast = Broadcast(iteration, shape.dimensions());
+      return std::vector<XlaOp>{
+          Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
+                 ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
+    };
+
+    // Determine if the inputs are out of range.
+    auto result_is_nan =
+        Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
+              Lt(x, ScalarLike(x, 0.0))),
+           Gt(x, ScalarLike(x, 1.0)));
+
+    // The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
+    // as per: http://dlmf.nist.gov/8.17.E23
+    //
+    // Otherwise, we can rewrite using the symmetry relation as per:
+    // http://dlmf.nist.gov/8.17.E4
+    auto converges_rapidly =
+        Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
+    auto a_orig = a;
+    a = Select(converges_rapidly, a, b);
+    b = Select(converges_rapidly, b, a_orig);
+    x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));
+
+    XlaOp continued_fraction;
+
+    // Thresholds and iteration counts taken from Cephes.
+    if (shape.element_type() == F32) {
+      continued_fraction = LentzThompsonBarnettAlgorithm(
+          /*num_iterations=*/200,
+          /*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
+          /*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
+          /*nth_partial_numerator=*/NthPartialBetaincNumerator,
+          /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
+          "Betainc");
+    } else {
+      TF_RET_CHECK(shape.element_type() == F64);
+      continued_fraction = LentzThompsonBarnettAlgorithm(
+          /*num_iterations=*/600,
+          /*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
+          /*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
+          /*nth_partial_numerator=*/NthPartialBetaincNumerator,
+          /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
+          "Betainc");
+    }
+
+    // We want to compute the regularized incomplete beta function so we need
+    // to combine the continued fraction with a few more terms as well as
+    // dividing it by Beta(a, b). To avoid overflow, we compute in the log
+    // domain. See http://dlmf.nist.gov/8.17.E22 for an easier to read version
+    // of this formula.
+    auto lbeta = Lbeta(a, b);
+    auto result =
+        continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
+    result =
+        Select(result_is_nan, NanValue(&builder, shape.element_type()), result);
+
+    // We have an additional fixup to do if we are taking advantage of the
+    // symmetry relation.
+    return Select(converges_rapidly, result,
+                  Sub(FullLike(result, 1.0), result));
+  });
+}
+
 }  // namespace xla
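The swap-and-subtract fixup at the end is the symmetry relation I_x(a, b) = 1 - I_{1-x}(b, a) from DLMF 8.17.4; a one-line SciPy check (illustrative only):

```python
import numpy as np
from scipy.special import betainc

a, b, x = 0.3, 4.0, 0.9  # here x >= (a + 1) / (a + b + 2): the slow case
# DLMF 8.17.4: the slowly converging case is evaluated as the rapidly
# converging one, with arguments swapped and x replaced by 1 - x.
assert np.isclose(betainc(a, b, x), 1.0 - betainc(b, a, 1.0 - x))
```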
@@ -108,6 +108,9 @@ XlaOp BesselI0e(XlaOp x);
 // at x.
 XlaOp BesselI1e(XlaOp x);

+// Computes the Regularized Incomplete Beta function.
+XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x);
+
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
@@ -50,49 +50,51 @@ class BetaincTest(test.TestCase):
     tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
     with self.cached_session():
       tf_out = self.evaluate(tf_out_t)
-    scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
+    scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt)

     # the scipy version of betainc uses a double-only implementation.
     # TODO(ebrevdo): identify reasons for (sometime) precision loss
     # with doubles
-    tol = 1e-4 if dtype == dtypes.float32 else 5e-5
-    self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=0)
+    rtol = 1e-4 if dtype == dtypes.float32 else 5e-5
+    atol = 9e-6 if dtype == dtypes.float32 else 3e-6
+    self.assertAllCloseAccordingToType(
+        scipy_out, tf_out, rtol=rtol, atol=atol)

     # Test out-of-range values (most should return nan output)
     combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
     a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
     with self.cached_session():
       tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
-    scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
+    scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
     self.assertAllCloseAccordingToType(scipy_comb, tf_comb)

     # Test broadcasting between scalars and other shapes
     with self.cached_session():
       self.assertAllCloseAccordingToType(
-          special.betainc(0.1, b_s, x_s).astype(np_dt),
+          special.betainc(0.1, b_s, x_s, dtype=np_dt),
           math_ops.betainc(0.1, b_s, x_s).eval(),
-          rtol=tol,
-          atol=0)
+          rtol=rtol,
+          atol=atol)
       self.assertAllCloseAccordingToType(
-          special.betainc(a_s, 0.1, x_s).astype(np_dt),
+          special.betainc(a_s, 0.1, x_s, dtype=np_dt),
           math_ops.betainc(a_s, 0.1, x_s).eval(),
-          rtol=tol,
-          atol=0)
+          rtol=rtol,
+          atol=atol)
       self.assertAllCloseAccordingToType(
-          special.betainc(a_s, b_s, 0.1).astype(np_dt),
+          special.betainc(a_s, b_s, 0.1, dtype=np_dt),
           math_ops.betainc(a_s, b_s, 0.1).eval(),
-          rtol=tol,
-          atol=0)
+          rtol=rtol,
+          atol=atol)
       self.assertAllCloseAccordingToType(
-          special.betainc(0.1, b_s, 0.1).astype(np_dt),
+          special.betainc(0.1, b_s, 0.1, dtype=np_dt),
           math_ops.betainc(0.1, b_s, 0.1).eval(),
-          rtol=tol,
-          atol=0)
+          rtol=rtol,
+          atol=atol)
       self.assertAllCloseAccordingToType(
-          special.betainc(0.1, 0.1, 0.1).astype(np_dt),
+          special.betainc(0.1, 0.1, 0.1, dtype=np_dt),
           math_ops.betainc(0.1, 0.1, 0.1).eval(),
-          rtol=tol,
-          atol=0)
+          rtol=rtol,
+          atol=atol)

     with self.assertRaisesRegexp(ValueError, "must be equal"):
       math_ops.betainc(0.5, [0.5], [[0.5]])
@@ -322,7 +322,7 @@ class BetaTest(test.TestCase):
     self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
     if not stats:
       return
-    self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)

   def testBetaLogCdf(self):
     shape = (30, 40, 50)
@@ -335,7 +335,7 @@ class BetaTest(test.TestCase):
     self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
     if not stats:
       return
-    self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)

   def testBetaWithSoftplusConcentration(self):
     a, b = -4.2, -9.1