diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index dd9cf615e4d..ed31bd0b683 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1022,6 +1022,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index d81a06a99ca..465f368db82 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +import scipy.special as sps from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ from tensorflow.python.platform import googletest class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): - def _testTernary(self, op, a, b, c, expected): + def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6): with self.session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") @@ -39,7 +40,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c") output = op(pa, pb, pc) result = session.run(output, {pa: a, pb: b, pc: c}) - self.assertAllClose(result, expected, rtol=1e-3) + self.assertAllClose(result, expected, rtol=rtol, atol=atol) return result @parameterized.parameters( @@ -210,6 +211,55 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) + def testBetaincSanity(self): + # This operation is only supported for float32 and float64. + for dtype in self.numeric_types & {np.float32, np.float64}: + # Sanity check a few identities: + # - betainc(a, b, 0) == 0 + # - betainc(a, b, 1) == 1 + # - betainc(a, 1, x) == x ** a + # Compare against the implementation in SciPy. + a = np.array([.3, .4, .2, .2], dtype=dtype) + b = np.array([1., 1., .4, .4], dtype=dtype) + x = np.array([.3, .4, .0, .1], dtype=dtype) + expected = sps.betainc(a, b, x) + self._testTernary( + math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6) + + @parameterized.parameters( + { + 'sigma': 1e15, + 'rtol': 1e-6, + 'atol': 1e-6 + }, + { + 'sigma': 30, + 'rtol': 1e-6, + 'atol': 2e-3 + }, + { + 'sigma': 1e-8, + 'rtol': 5e-4, + 'atol': 3e-6 + }, + { + 'sigma': 1e-16, + 'rtol': 1e-6, + 'atol': 2e-4 + }, + ) + def testBetainc(self, sigma, rtol, atol): + # This operation is only supported for float32 and float64. + for dtype in self.numeric_types & {np.float32, np.float64}: + # Randomly generate a, b, x in the numerical domain of betainc. + # Compare against the implementation in SciPy. + a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty) + b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty) + x = np.random.rand(10, 10).astype(dtype) # in (0, 1) + expected = sps.betainc(a, b, x, dtype=dtype) + self._testTernary( + math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b6e89c6540b..762c9b80bbe 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -14,6 +14,7 @@ tf_kernel_library( "batch_norm_op.cc", "batchtospace_op.cc", "bcast_ops.cc", + "beta_op.cc", "bias_ops.cc", "binary_ops.cc", "broadcast_to_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc new file mode 100644 index 00000000000..aa4a8cae118 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { +namespace { + +class BetaincOp : public XlaOpKernel { + public: + explicit BetaincOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape& a_shape = ctx->InputShape(0); + const TensorShape& b_shape = ctx->InputShape(1); + const TensorShape& x_shape = ctx->InputShape(2); + if (a_shape.dims() > 0 && b_shape.dims() > 0) { + OP_REQUIRES(ctx, a_shape == b_shape, + errors::InvalidArgument( + "Shapes of a and b are inconsistent: ", + a_shape.DebugString(), " vs. ", b_shape.DebugString())); + } + if (a_shape.dims() > 0 && x_shape.dims() > 0) { + OP_REQUIRES(ctx, a_shape == x_shape, + errors::InvalidArgument( + "Shapes of a and x are inconsistent: ", + a_shape.DebugString(), " vs. ", x_shape.DebugString())); + } + if (b_shape.dims() > 0 && x_shape.dims() > 0) { + OP_REQUIRES(ctx, b_shape == x_shape, + errors::InvalidArgument( + "Shapes of b and x are inconsistent: ", + b_shape.DebugString(), " vs. ", x_shape.DebugString())); + } + + TensorShape merged_shape(a_shape); + if (b_shape.dims() > 0) merged_shape = b_shape; + if (x_shape.dims() > 0) merged_shape = x_shape; + + auto builder = ctx->builder(); + auto result = + builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes())); + TF_ASSIGN_OR_RETURN( + auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes())); + TF_ASSIGN_OR_RETURN( + auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes())); + return xla::RegularizedIncompleteBeta(a, b, x); + }); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("Betainc"), BetaincOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 99bccdb3bb8..ead07825412 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -145,6 +145,7 @@ cc_library( deps = [ ":arithmetic", ":constants", + ":loops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 5dedd1de77c..989968b5cbc 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -497,6 +499,30 @@ XlaOp Lgamma(XlaOp input) { }); } +// Computes an approximation of the lbeta function which is equivalent to +// log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma. +static XlaOp Lbeta(XlaOp a, XlaOp b) { + // Beta(a, b) can be computed using Gamma as per + // http://dlmf.nist.gov/5.12.E1 as follows: + // Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b) + // + // To avoid overflow, we compute in the log domain. + // + // As per http://dlmf.nist.gov/4.8.E2 we can transform: + // Log(a * b) + // into: + // Log(a) + Log(b) + // + // Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn: + // Log(a - b) + // into: + // Log(a) - Log(b) + // + // This means that we can compute Log(Beta(a, b)) by: + // Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b)) + return Lgamma(a) + Lgamma(b) - Lgamma(a + b); +} + // Compute the Digamma function using Lanczos' approximation from "A Precision // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis // series B. Vol. 1: @@ -1037,4 +1063,212 @@ XlaOp BesselI1e(XlaOp x) { }); } +// I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex +// arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509. +// DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X +static XlaOp LentzThompsonBarnettAlgorithm( + int64 num_iterations, double small, double threshold, + const ForEachIndexBodyFunction& nth_partial_numerator, + const ForEachIndexBodyFunction& nth_partial_denominator, + absl::Span inputs, absl::string_view name) { + auto& b = *inputs.front().builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(num_iterations < INT32_MAX); + + enum { + // Position in the evaluation. + kIterationIdx, + // Whether or not we have reached the desired tolerance. + kValuesUnconvergedIdx, + // Ratio between nth canonical numerator and the nth-1 canonical + // numerator. + kCIdx, + // Ratio between nth-1 canonical denominator and the nth canonical + // denominator. + kDIdx, + // Computed approximant in the evaluation. + kHIdx, + // Inputs follow all of the other state. + kFirstInputIdx, + }; + auto while_cond_fn = [num_iterations]( + absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto iteration = values[kIterationIdx]; + auto iterations_remain_cond = + Lt(iteration, ScalarLike(iteration, num_iterations)); + auto values_unconverged_cond = values[kValuesUnconvergedIdx]; + return And(iterations_remain_cond, values_unconverged_cond); + }; + + auto while_body_fn = + [small, threshold, &nth_partial_numerator, &nth_partial_denominator]( + absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[kIterationIdx]; + + TF_ASSIGN_OR_RETURN( + std::vector partial_numerator, + nth_partial_numerator(iteration, values.subspan(kFirstInputIdx), + body_builder)); + TF_RET_CHECK(partial_numerator.size() == 1); + + TF_ASSIGN_OR_RETURN( + std::vector partial_denominator, + nth_partial_denominator(iteration, values.subspan(kFirstInputIdx), + body_builder)); + TF_RET_CHECK(partial_denominator.size() == 1); + + auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx]; + auto small_constant = FullLike(c, small); + c = Select(Lt(Abs(c), small_constant), small_constant, c); + + auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx]; + d = Select(Lt(Abs(d), small_constant), small_constant, d); + + d = Reciprocal(d); + + auto delta = c * d; + auto h = values[kHIdx] * delta; + + std::vector updated_values(values.size()); + updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1)); + updated_values[kCIdx] = c; + updated_values[kDIdx] = d; + updated_values[kHIdx] = h; + std::copy(values.begin() + kFirstInputIdx, values.end(), + updated_values.begin() + kFirstInputIdx); + + // If any values are greater than the tolerance, we have not converged. + auto tolerance_comparison = + Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold)); + updated_values[kValuesUnconvergedIdx] = + ReduceAll(tolerance_comparison, ConstantR0(body_builder, false), + CreateScalarOrComputation(PRED, body_builder)); + return updated_values; + }; + + TF_ASSIGN_OR_RETURN(std::vector partial_denominator, + nth_partial_denominator(Zero(&b, U32), inputs, &b)); + TF_RET_CHECK(partial_denominator.size() == 1); + auto h = partial_denominator[0]; + auto small_constant = FullLike(h, small); + h = Select(Lt(Abs(h), small_constant), small_constant, h); + + std::vector values(kFirstInputIdx + inputs.size()); + values[kIterationIdx] = One(&b, U32); + values[kValuesUnconvergedIdx] = ConstantR0(&b, true); + values[kCIdx] = h; + values[kDIdx] = FullLike(h, 0.0); + values[kHIdx] = h; + std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, &b)); + return values[kHIdx]; + }); +} + +XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) { + auto& builder = *x.builder(); + return builder.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a)); + + // The partial numerator for the incomplete beta function is given + // here: http://dlmf.nist.gov/8.17.E23 Note that there is a special + // case: the partial numerator for the first iteration is one. + auto NthPartialBetaincNumerator = + [&shape](XlaOp iteration, absl::Span inputs, + XlaBuilder* builder) -> StatusOr> { + auto a = inputs[0]; + auto b = inputs[1]; + auto x = inputs[2]; + auto iteration_bcast = Broadcast(iteration, shape.dimensions()); + auto iteration_is_even = + Eq(iteration_bcast % FullLike(iteration_bcast, 2), + FullLike(iteration_bcast, 0)); + auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1)); + auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1); + auto m = iteration_minus_one / FullLike(iteration_minus_one, 2); + m = ConvertElementType(m, shape.element_type()); + auto one = FullLike(a, 1.0); + auto two = FullLike(a, 2.0); + // Partial numerator terms. + auto even_numerator = + -(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one)); + auto odd_numerator = + m * (b - m) * x / ((a + two * m - one) * (a + two * m)); + auto one_numerator = ScalarLike(x, 1.0); + auto numerator = Select(iteration_is_even, even_numerator, odd_numerator); + return std::vector{ + Select(iteration_is_one, one_numerator, numerator)}; + }; + + auto NthPartialBetaincDenominator = + [&shape](XlaOp iteration, absl::Span inputs, + XlaBuilder* builder) -> StatusOr> { + auto x = inputs[2]; + auto iteration_bcast = Broadcast(iteration, shape.dimensions()); + return std::vector{ + Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)), + ScalarLike(x, 0.0), ScalarLike(x, 1.0))}; + }; + + // Determine if the inputs are out of range. + auto result_is_nan = + Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))), + Lt(x, ScalarLike(x, 0.0))), + Gt(x, ScalarLike(x, 1.0))); + + // The continued fraction will converge rapidly when x < (a+1)/(a+b+2) + // as per: http://dlmf.nist.gov/8.17.E23 + // + // Otherwise, we can rewrite using the symmetry relation as per: + // http://dlmf.nist.gov/8.17.E4 + auto converges_rapidly = + Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0))); + auto a_orig = a; + a = Select(converges_rapidly, a, b); + b = Select(converges_rapidly, b, a_orig); + x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x)); + + XlaOp continued_fraction; + + // Thresholds and iteration counts taken from Cephes. + if (shape.element_type() == F32) { + continued_fraction = LentzThompsonBarnettAlgorithm( + /*num_iterations=*/200, + /*small=*/std::numeric_limits::epsilon() / 2.0f, + /*threshold=*/std::numeric_limits::epsilon() / 2.0f, + /*nth_partial_numerator=*/NthPartialBetaincNumerator, + /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x}, + "Betainc"); + } else { + TF_RET_CHECK(shape.element_type() == F64); + continued_fraction = LentzThompsonBarnettAlgorithm( + /*num_iterations=*/600, + /*small=*/std::numeric_limits::epsilon() / 2.0f, + /*threshold=*/std::numeric_limits::epsilon() / 2.0f, + /*nth_partial_numerator=*/NthPartialBetaincNumerator, + /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x}, + "Betainc"); + } + + // We want to compute the regularized complete beta function so we need to + // combine the continued fraction with a few more terms as well as dividing + // it by Beta(a, b). To avoid overflow, we compute in the log domain. + // See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this + // formula. + auto lbeta = Lbeta(a, b); + auto result = + continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a; + result = + Select(result_is_nan, NanValue(&builder, shape.element_type()), result); + + // We have an additional fixup to do if we are taking advantage of the + // symmetry relation. + return Select(converges_rapidly, result, + Sub(FullLike(result, 1.0), result)); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 89435c309a3..3a0b870f8d8 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -108,6 +108,9 @@ XlaOp BesselI0e(XlaOp x); // at x. XlaOp BesselI1e(XlaOp x); +// Computes the Regularized Incomplete Beta function. +XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py index 9dc34a60628..c4f70b5bc29 100644 --- a/tensorflow/python/kernel_tests/betainc_op_test.py +++ b/tensorflow/python/kernel_tests/betainc_op_test.py @@ -50,49 +50,51 @@ class BetaincTest(test.TestCase): tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s) with self.cached_session(): tf_out = self.evaluate(tf_out_t) - scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt) + scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt) # the scipy version of betainc uses a double-only implementation. # TODO(ebrevdo): identify reasons for (sometime) precision loss # with doubles - tol = 1e-4 if dtype == dtypes.float32 else 5e-5 - self.assertAllCloseAccordingToType(scipy_out, tf_out, rtol=tol, atol=0) + rtol = 1e-4 if dtype == dtypes.float32 else 5e-5 + atol = 9e-6 if dtype == dtypes.float32 else 3e-6 + self.assertAllCloseAccordingToType( + scipy_out, tf_out, rtol=rtol, atol=atol) # Test out-of-range values (most should return nan output) combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3)) a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt) with self.cached_session(): tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval() - scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt) + scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt) self.assertAllCloseAccordingToType(scipy_comb, tf_comb) # Test broadcasting between scalars and other shapes with self.cached_session(): self.assertAllCloseAccordingToType( - special.betainc(0.1, b_s, x_s).astype(np_dt), + special.betainc(0.1, b_s, x_s, dtype=np_dt), math_ops.betainc(0.1, b_s, x_s).eval(), - rtol=tol, - atol=0) + rtol=rtol, + atol=atol) self.assertAllCloseAccordingToType( - special.betainc(a_s, 0.1, x_s).astype(np_dt), + special.betainc(a_s, 0.1, x_s, dtype=np_dt), math_ops.betainc(a_s, 0.1, x_s).eval(), - rtol=tol, - atol=0) + rtol=rtol, + atol=atol) self.assertAllCloseAccordingToType( - special.betainc(a_s, b_s, 0.1).astype(np_dt), + special.betainc(a_s, b_s, 0.1, dtype=np_dt), math_ops.betainc(a_s, b_s, 0.1).eval(), - rtol=tol, - atol=0) + rtol=rtol, + atol=atol) self.assertAllCloseAccordingToType( - special.betainc(0.1, b_s, 0.1).astype(np_dt), + special.betainc(0.1, b_s, 0.1, dtype=np_dt), math_ops.betainc(0.1, b_s, 0.1).eval(), - rtol=tol, - atol=0) + rtol=rtol, + atol=atol) self.assertAllCloseAccordingToType( - special.betainc(0.1, 0.1, 0.1).astype(np_dt), + special.betainc(0.1, 0.1, 0.1, dtype=np_dt), math_ops.betainc(0.1, 0.1, 0.1).eval(), - rtol=tol, - atol=0) + rtol=rtol, + atol=atol) with self.assertRaisesRegexp(ValueError, "must be equal"): math_ops.betainc(0.5, [0.5], [[0.5]]) diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 42e81bd6589..fa4a4a0822e 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -322,7 +322,7 @@ class BetaTest(test.TestCase): self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) if not stats: return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6) def testBetaLogCdf(self): shape = (30, 40, 50) @@ -335,7 +335,7 @@ class BetaTest(test.TestCase): self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) if not stats: return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5) def testBetaWithSoftplusConcentration(self): a, b = -4.2, -9.1