[XLA] Implement the regularized incomplete beta function
An implementation involving continued fractions is used. The continued fraction for this function can be seen at http://dlmf.nist.gov/8.17.v PiperOrigin-RevId: 277327138 Change-Id: Icb8527af344b629806fd7e4880072d05e2530cd2
This commit is contained in:
parent
2061ec8104
commit
e5983061d3
@ -1022,6 +1022,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import scipy.special as sps
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,7 +32,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _testTernary(self, op, a, b, c, expected):
|
||||
def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
|
||||
@ -39,7 +40,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
|
||||
output = op(pa, pb, pc)
|
||||
result = session.run(output, {pa: a, pb: b, pc: c})
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
self.assertAllClose(result, expected, rtol=rtol, atol=atol)
|
||||
return result
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -210,6 +211,55 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
upper,
|
||||
expected=np.minimum(np.maximum(x, lower), upper))
|
||||
|
||||
def testBetaincSanity(self):
|
||||
# This operation is only supported for float32 and float64.
|
||||
for dtype in self.numeric_types & {np.float32, np.float64}:
|
||||
# Sanity check a few identities:
|
||||
# - betainc(a, b, 0) == 0
|
||||
# - betainc(a, b, 1) == 1
|
||||
# - betainc(a, 1, x) == x ** a
|
||||
# Compare against the implementation in SciPy.
|
||||
a = np.array([.3, .4, .2, .2], dtype=dtype)
|
||||
b = np.array([1., 1., .4, .4], dtype=dtype)
|
||||
x = np.array([.3, .4, .0, .1], dtype=dtype)
|
||||
expected = sps.betainc(a, b, x)
|
||||
self._testTernary(
|
||||
math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'sigma': 1e15,
|
||||
'rtol': 1e-6,
|
||||
'atol': 1e-6
|
||||
},
|
||||
{
|
||||
'sigma': 30,
|
||||
'rtol': 1e-6,
|
||||
'atol': 2e-3
|
||||
},
|
||||
{
|
||||
'sigma': 1e-8,
|
||||
'rtol': 5e-4,
|
||||
'atol': 3e-6
|
||||
},
|
||||
{
|
||||
'sigma': 1e-16,
|
||||
'rtol': 1e-6,
|
||||
'atol': 2e-4
|
||||
},
|
||||
)
|
||||
def testBetainc(self, sigma, rtol, atol):
|
||||
# This operation is only supported for float32 and float64.
|
||||
for dtype in self.numeric_types & {np.float32, np.float64}:
|
||||
# Randomly generate a, b, x in the numerical domain of betainc.
|
||||
# Compare against the implementation in SciPy.
|
||||
a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty)
|
||||
b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty)
|
||||
x = np.random.rand(10, 10).astype(dtype) # in (0, 1)
|
||||
expected = sps.betainc(a, b, x, dtype=dtype)
|
||||
self._testTernary(
|
||||
math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -14,6 +14,7 @@ tf_kernel_library(
|
||||
"batch_norm_op.cc",
|
||||
"batchtospace_op.cc",
|
||||
"bcast_ops.cc",
|
||||
"beta_op.cc",
|
||||
"bias_ops.cc",
|
||||
"binary_ops.cc",
|
||||
"broadcast_to_op.cc",
|
||||
|
81
tensorflow/compiler/tf2xla/kernels/beta_op.cc
Normal file
81
tensorflow/compiler/tf2xla/kernels/beta_op.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class BetaincOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit BetaincOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape& a_shape = ctx->InputShape(0);
|
||||
const TensorShape& b_shape = ctx->InputShape(1);
|
||||
const TensorShape& x_shape = ctx->InputShape(2);
|
||||
if (a_shape.dims() > 0 && b_shape.dims() > 0) {
|
||||
OP_REQUIRES(ctx, a_shape == b_shape,
|
||||
errors::InvalidArgument(
|
||||
"Shapes of a and b are inconsistent: ",
|
||||
a_shape.DebugString(), " vs. ", b_shape.DebugString()));
|
||||
}
|
||||
if (a_shape.dims() > 0 && x_shape.dims() > 0) {
|
||||
OP_REQUIRES(ctx, a_shape == x_shape,
|
||||
errors::InvalidArgument(
|
||||
"Shapes of a and x are inconsistent: ",
|
||||
a_shape.DebugString(), " vs. ", x_shape.DebugString()));
|
||||
}
|
||||
if (b_shape.dims() > 0 && x_shape.dims() > 0) {
|
||||
OP_REQUIRES(ctx, b_shape == x_shape,
|
||||
errors::InvalidArgument(
|
||||
"Shapes of b and x are inconsistent: ",
|
||||
b_shape.DebugString(), " vs. ", x_shape.DebugString()));
|
||||
}
|
||||
|
||||
TensorShape merged_shape(a_shape);
|
||||
if (b_shape.dims() > 0) merged_shape = b_shape;
|
||||
if (x_shape.dims() > 0) merged_shape = x_shape;
|
||||
|
||||
auto builder = ctx->builder();
|
||||
auto result =
|
||||
builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes()));
|
||||
return xla::RegularizedIncompleteBeta(a, b, x);
|
||||
});
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("Betainc"), BetaincOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -145,6 +145,7 @@ cc_library(
|
||||
deps = [
|
||||
":arithmetic",
|
||||
":constants",
|
||||
":loops",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -497,6 +499,30 @@ XlaOp Lgamma(XlaOp input) {
|
||||
});
|
||||
}
|
||||
|
||||
// Computes an approximation of the lbeta function which is equivalent to
|
||||
// log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
|
||||
static XlaOp Lbeta(XlaOp a, XlaOp b) {
|
||||
// Beta(a, b) can be computed using Gamma as per
|
||||
// http://dlmf.nist.gov/5.12.E1 as follows:
|
||||
// Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
|
||||
//
|
||||
// To avoid overflow, we compute in the log domain.
|
||||
//
|
||||
// As per http://dlmf.nist.gov/4.8.E2 we can transform:
|
||||
// Log(a * b)
|
||||
// into:
|
||||
// Log(a) + Log(b)
|
||||
//
|
||||
// Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
|
||||
// Log(a - b)
|
||||
// into:
|
||||
// Log(a) - Log(b)
|
||||
//
|
||||
// This means that we can compute Log(Beta(a, b)) by:
|
||||
// Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
|
||||
return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
|
||||
}
|
||||
|
||||
// Compute the Digamma function using Lanczos' approximation from "A Precision
|
||||
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
|
||||
// series B. Vol. 1:
|
||||
@ -1037,4 +1063,212 @@ XlaOp BesselI1e(XlaOp x) {
|
||||
});
|
||||
}
|
||||
|
||||
// I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
|
||||
// arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
|
||||
// DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
|
||||
static XlaOp LentzThompsonBarnettAlgorithm(
|
||||
int64 num_iterations, double small, double threshold,
|
||||
const ForEachIndexBodyFunction& nth_partial_numerator,
|
||||
const ForEachIndexBodyFunction& nth_partial_denominator,
|
||||
absl::Span<const XlaOp> inputs, absl::string_view name) {
|
||||
auto& b = *inputs.front().builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_RET_CHECK(num_iterations < INT32_MAX);
|
||||
|
||||
enum {
|
||||
// Position in the evaluation.
|
||||
kIterationIdx,
|
||||
// Whether or not we have reached the desired tolerance.
|
||||
kValuesUnconvergedIdx,
|
||||
// Ratio between nth canonical numerator and the nth-1 canonical
|
||||
// numerator.
|
||||
kCIdx,
|
||||
// Ratio between nth-1 canonical denominator and the nth canonical
|
||||
// denominator.
|
||||
kDIdx,
|
||||
// Computed approximant in the evaluation.
|
||||
kHIdx,
|
||||
// Inputs follow all of the other state.
|
||||
kFirstInputIdx,
|
||||
};
|
||||
auto while_cond_fn = [num_iterations](
|
||||
absl::Span<const XlaOp> values,
|
||||
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
|
||||
auto iteration = values[kIterationIdx];
|
||||
auto iterations_remain_cond =
|
||||
Lt(iteration, ScalarLike(iteration, num_iterations));
|
||||
auto values_unconverged_cond = values[kValuesUnconvergedIdx];
|
||||
return And(iterations_remain_cond, values_unconverged_cond);
|
||||
};
|
||||
|
||||
auto while_body_fn =
|
||||
[small, threshold, &nth_partial_numerator, &nth_partial_denominator](
|
||||
absl::Span<const XlaOp> values,
|
||||
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
XlaOp iteration = values[kIterationIdx];
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<XlaOp> partial_numerator,
|
||||
nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
|
||||
body_builder));
|
||||
TF_RET_CHECK(partial_numerator.size() == 1);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<XlaOp> partial_denominator,
|
||||
nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
|
||||
body_builder));
|
||||
TF_RET_CHECK(partial_denominator.size() == 1);
|
||||
|
||||
auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
|
||||
auto small_constant = FullLike(c, small);
|
||||
c = Select(Lt(Abs(c), small_constant), small_constant, c);
|
||||
|
||||
auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
|
||||
d = Select(Lt(Abs(d), small_constant), small_constant, d);
|
||||
|
||||
d = Reciprocal(d);
|
||||
|
||||
auto delta = c * d;
|
||||
auto h = values[kHIdx] * delta;
|
||||
|
||||
std::vector<XlaOp> updated_values(values.size());
|
||||
updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
|
||||
updated_values[kCIdx] = c;
|
||||
updated_values[kDIdx] = d;
|
||||
updated_values[kHIdx] = h;
|
||||
std::copy(values.begin() + kFirstInputIdx, values.end(),
|
||||
updated_values.begin() + kFirstInputIdx);
|
||||
|
||||
// If any values are greater than the tolerance, we have not converged.
|
||||
auto tolerance_comparison =
|
||||
Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
|
||||
updated_values[kValuesUnconvergedIdx] =
|
||||
ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
|
||||
CreateScalarOrComputation(PRED, body_builder));
|
||||
return updated_values;
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
|
||||
nth_partial_denominator(Zero(&b, U32), inputs, &b));
|
||||
TF_RET_CHECK(partial_denominator.size() == 1);
|
||||
auto h = partial_denominator[0];
|
||||
auto small_constant = FullLike(h, small);
|
||||
h = Select(Lt(Abs(h), small_constant), small_constant, h);
|
||||
|
||||
std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
|
||||
values[kIterationIdx] = One(&b, U32);
|
||||
values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
|
||||
values[kCIdx] = h;
|
||||
values[kDIdx] = FullLike(h, 0.0);
|
||||
values[kHIdx] = h;
|
||||
std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
|
||||
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
|
||||
values, name, &b));
|
||||
return values[kHIdx];
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
|
||||
auto& builder = *x.builder();
|
||||
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
|
||||
|
||||
// The partial numerator for the incomplete beta function is given
|
||||
// here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
|
||||
// case: the partial numerator for the first iteration is one.
|
||||
auto NthPartialBetaincNumerator =
|
||||
[&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
|
||||
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
auto a = inputs[0];
|
||||
auto b = inputs[1];
|
||||
auto x = inputs[2];
|
||||
auto iteration_bcast = Broadcast(iteration, shape.dimensions());
|
||||
auto iteration_is_even =
|
||||
Eq(iteration_bcast % FullLike(iteration_bcast, 2),
|
||||
FullLike(iteration_bcast, 0));
|
||||
auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
|
||||
auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
|
||||
auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
|
||||
m = ConvertElementType(m, shape.element_type());
|
||||
auto one = FullLike(a, 1.0);
|
||||
auto two = FullLike(a, 2.0);
|
||||
// Partial numerator terms.
|
||||
auto even_numerator =
|
||||
-(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
|
||||
auto odd_numerator =
|
||||
m * (b - m) * x / ((a + two * m - one) * (a + two * m));
|
||||
auto one_numerator = ScalarLike(x, 1.0);
|
||||
auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
|
||||
return std::vector<XlaOp>{
|
||||
Select(iteration_is_one, one_numerator, numerator)};
|
||||
};
|
||||
|
||||
auto NthPartialBetaincDenominator =
|
||||
[&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
|
||||
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
auto x = inputs[2];
|
||||
auto iteration_bcast = Broadcast(iteration, shape.dimensions());
|
||||
return std::vector<XlaOp>{
|
||||
Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
|
||||
ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
|
||||
};
|
||||
|
||||
// Determine if the inputs are out of range.
|
||||
auto result_is_nan =
|
||||
Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
|
||||
Lt(x, ScalarLike(x, 0.0))),
|
||||
Gt(x, ScalarLike(x, 1.0)));
|
||||
|
||||
// The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
|
||||
// as per: http://dlmf.nist.gov/8.17.E23
|
||||
//
|
||||
// Otherwise, we can rewrite using the symmetry relation as per:
|
||||
// http://dlmf.nist.gov/8.17.E4
|
||||
auto converges_rapidly =
|
||||
Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
|
||||
auto a_orig = a;
|
||||
a = Select(converges_rapidly, a, b);
|
||||
b = Select(converges_rapidly, b, a_orig);
|
||||
x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));
|
||||
|
||||
XlaOp continued_fraction;
|
||||
|
||||
// Thresholds and iteration counts taken from Cephes.
|
||||
if (shape.element_type() == F32) {
|
||||
continued_fraction = LentzThompsonBarnettAlgorithm(
|
||||
/*num_iterations=*/200,
|
||||
/*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
|
||||
/*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
|
||||
/*nth_partial_numerator=*/NthPartialBetaincNumerator,
|
||||
/*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
|
||||
"Betainc");
|
||||
} else {
|
||||
TF_RET_CHECK(shape.element_type() == F64);
|
||||
continued_fraction = LentzThompsonBarnettAlgorithm(
|
||||
/*num_iterations=*/600,
|
||||
/*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
|
||||
/*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
|
||||
/*nth_partial_numerator=*/NthPartialBetaincNumerator,
|
||||
/*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
|
||||
"Betainc");
|
||||
}
|
||||
|
||||
// We want to compute the regularized complete beta function so we need to
|
||||
// combine the continued fraction with a few more terms as well as dividing
|
||||
// it by Beta(a, b). To avoid overflow, we compute in the log domain.
|
||||
// See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
|
||||
// formula.
|
||||
auto lbeta = Lbeta(a, b);
|
||||
auto result =
|
||||
continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
|
||||
result =
|
||||
Select(result_is_nan, NanValue(&builder, shape.element_type()), result);
|
||||
|
||||
// We have an additional fixup to do if we are taking advantage of the
|
||||
// symmetry relation.
|
||||
return Select(converges_rapidly, result,
|
||||
Sub(FullLike(result, 1.0), result));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -108,6 +108,9 @@ XlaOp BesselI0e(XlaOp x);
|
||||
// at x.
|
||||
XlaOp BesselI1e(XlaOp x);
|
||||
|
||||
// Computes the Regularized Incomplete Beta function.
|
||||
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
|
||||
|
@ -50,49 +50,51 @@ class BetaincTest(test.TestCase):
|
||||
tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
|
||||
with self.cached_session():
|
||||
tf_out = self.evaluate(tf_out_t)
|
||||
scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
|
||||
scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt)
|
||||
|
||||
# the scipy version of betainc uses a double-only implementation.
|
||||
# TODO(ebrevdo): identify reasons for (sometime) precision loss
|
||||
# with doubles
|
||||
tol = 1e-4 if dtype == dtypes.float32 else 5e-5
|
||||
self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=0)
|
||||
rtol = 1e-4 if dtype == dtypes.float32 else 5e-5
|
||||
atol = 9e-6 if dtype == dtypes.float32 else 3e-6
|
||||
self.assertAllCloseAccordingToType(
|
||||
scipy_out, tf_out, rtol=rtol, atol=atol)
|
||||
|
||||
# Test out-of-range values (most should return nan output)
|
||||
combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
|
||||
a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
|
||||
with self.cached_session():
|
||||
tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
|
||||
scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
|
||||
scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
|
||||
self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
|
||||
|
||||
# Test broadcasting between scalars and other shapes
|
||||
with self.cached_session():
|
||||
self.assertAllCloseAccordingToType(
|
||||
special.betainc(0.1, b_s, x_s).astype(np_dt),
|
||||
special.betainc(0.1, b_s, x_s, dtype=np_dt),
|
||||
math_ops.betainc(0.1, b_s, x_s).eval(),
|
||||
rtol=tol,
|
||||
atol=0)
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
self.assertAllCloseAccordingToType(
|
||||
special.betainc(a_s, 0.1, x_s).astype(np_dt),
|
||||
special.betainc(a_s, 0.1, x_s, dtype=np_dt),
|
||||
math_ops.betainc(a_s, 0.1, x_s).eval(),
|
||||
rtol=tol,
|
||||
atol=0)
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
self.assertAllCloseAccordingToType(
|
||||
special.betainc(a_s, b_s, 0.1).astype(np_dt),
|
||||
special.betainc(a_s, b_s, 0.1, dtype=np_dt),
|
||||
math_ops.betainc(a_s, b_s, 0.1).eval(),
|
||||
rtol=tol,
|
||||
atol=0)
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
self.assertAllCloseAccordingToType(
|
||||
special.betainc(0.1, b_s, 0.1).astype(np_dt),
|
||||
special.betainc(0.1, b_s, 0.1, dtype=np_dt),
|
||||
math_ops.betainc(0.1, b_s, 0.1).eval(),
|
||||
rtol=tol,
|
||||
atol=0)
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
self.assertAllCloseAccordingToType(
|
||||
special.betainc(0.1, 0.1, 0.1).astype(np_dt),
|
||||
special.betainc(0.1, 0.1, 0.1, dtype=np_dt),
|
||||
math_ops.betainc(0.1, 0.1, 0.1).eval(),
|
||||
rtol=tol,
|
||||
atol=0)
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "must be equal"):
|
||||
math_ops.betainc(0.5, [0.5], [[0.5]])
|
||||
|
@ -322,7 +322,7 @@ class BetaTest(test.TestCase):
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
|
||||
if not stats:
|
||||
return
|
||||
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
|
||||
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)
|
||||
|
||||
def testBetaLogCdf(self):
|
||||
shape = (30, 40, 50)
|
||||
@ -335,7 +335,7 @@ class BetaTest(test.TestCase):
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
|
||||
if not stats:
|
||||
return
|
||||
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
|
||||
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)
|
||||
|
||||
def testBetaWithSoftplusConcentration(self):
|
||||
a, b = -4.2, -9.1
|
||||
|
Loading…
x
Reference in New Issue
Block a user