Add XLA registration for Zeta/Zetac

PiperOrigin-RevId: 332509521
Change-Id: If93869e3e4732cb09e85ec9cc61b2fe085fdb1ce
commit c8ee679b8e (parent 7813d49417)
Author: Srinivas Vasudevan, 2020-09-18 12:57:56 -07:00
Committed by: TensorFlower Gardener
8 changed files with 231 additions and 1 deletion
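For orientation, the end-to-end effect of this change is that tf.math.zeta can be lowered through XLA. The following is a minimal usage sketch, not part of the diff itself; it assumes a TensorFlow 2.x build with XLA support, SciPy for reference values, and the experimental_compile spelling of the flag as it existed at the time of this commit.

import numpy as np
import scipy.special as sps
import tensorflow as tf

# Force XLA compilation of the zeta computation (mirrors the _zeta helper
# added in the tests below).
@tf.function(experimental_compile=True)
def compiled_zeta(x, q):
  return tf.math.zeta(x, q)

x = np.random.uniform(low=1.1, high=10., size=[16]).astype(np.float64)
q = np.random.uniform(low=0.3, high=20., size=[16]).astype(np.float64)
# Compare the XLA-compiled result against SciPy's reference implementation.
np.testing.assert_allclose(compiled_zeta(x, q).numpy(), sps.zeta(x, q),
                           rtol=1e-4)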


@@ -2094,6 +2094,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile",
"Zeta",
"_Arg",
"_ArrayToList",
"_ListToArray",


@@ -31,6 +31,7 @@ import six
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gradient_checker_v2
@@ -54,6 +55,11 @@ def _igammac(a, x):
return math_ops.igammac(a, x)
@def_function.function(experimental_compile=True)
def _zeta(a, q):
return math_ops.zeta(a, q)
# This is df/da / df/dx, where f = igamma.
def implicit_reparameterization_grad(a, x):
log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
@@ -136,6 +142,120 @@ class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer % (2**32 - 1))
super(ZetaTest, self).setUp()
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 2e-2, 1e-7
return 2e-4, 1e-20
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testBadValues(self):
q = np.random.uniform(low=0.3, high=20., size=[10])
with self.session() as sess:
with self.test_scope():
y = _zeta(np.float64(1.), q)
actual = sess.run(y)
# When x == 1, this is the Harmonic series.
self.assertTrue(np.all(np.isinf(actual)))
with self.session() as sess:
with self.test_scope():
y = _zeta(np.float64(0.1), q)
actual = sess.run(y)
# When x < 1, this is undefined.
self.assertTrue(np.all(np.isnan(actual)))
with self.session() as sess:
with self.test_scope():
y = _zeta([1., 1.1], [-1.1, -1.])
actual = sess.run(y)
# x == 1 gives the harmonic series, and a non-positive integer q is a pole
# of zeta; in both cases the result is inf.
self.assertTrue(np.all(np.isinf(actual)))
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testLargeXSmallQ(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
# TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns
# infs.
self.skipTest(
'Skipping test because some F64 operations are numerically '
'unstable on TPU.')
x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
y = _zeta(x, q)
actual = sess.run(y)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testSmallValues(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
# Test values near zero.
x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testMediumValues(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testLargeValues(self, dtype, rtol, atol):
x = np.random.uniform(
low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype)
q = np.random.uniform(
low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.zeta(x, q)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_zeta(x, q))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):


@@ -290,6 +290,13 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
return xla::Zeta(x, q);
}
XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
#undef XLA_MAKE_BINARY
class ApproximateEqualOp : public XlaOpKernel {


@@ -206,6 +206,7 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
zeta = _broadcasting_binary_op(math_ops.zeta)
def _binary_op(fn):
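For context on the one-liner above: _broadcasting_binary_op exposes the op with NumPy-style broadcasting of its two operands, the same way igamma and friends are exposed earlier in this file. A small illustration of the resulting broadcasting semantics, shown here through the public tf.math.zeta endpoint rather than the internal xla.zeta symbol (which is intended for use inside compiled computations):

import tensorflow as tf

x = tf.constant([[2.], [3.], [4.]], dtype=tf.float64)  # shape [3, 1]
q = tf.constant([1., 2., 3., 4.], dtype=tf.float64)    # shape [4]
# [3, 1] broadcasts against [4]; zeta is evaluated elementwise on the
# resulting [3, 4] grid.
print(tf.math.zeta(x, q).shape)  # (3, 4)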


@@ -1832,4 +1832,98 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
});
}
XlaOp Zeta(XlaOp x, XlaOp q) {
auto& builder = *x.builder();
auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
// (2k)! / B_{2k}, where B_{2k} are the Bernoulli numbers.
// These are ordered in reverse.
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
-4.5979787224074726105e15,
1.1646782814350067249e14,
-2.950130727918164224e12,
7.47242496e10,
-1.8924375803183791606e9,
47900160.0,
-1209600.0,
30240.0,
-720.0,
12.0,
};
// For speed we'll always use 9 iterations for the initial series estimate,
// and a 12-term expansion for the Euler-Maclaurin formula.
XlaOp a = q;
XlaOp neg_power = ScalarLike(a, 0.);
XlaOp initial_sum = Pow(q, Neg(x));
for (int i = 0; i < 9; ++i) {
a = a + ScalarLike(a, 1.);
neg_power = Pow(a, Neg(x));
initial_sum = initial_sum + neg_power;
}
a = a + ScalarLike(a, 1.);
neg_power = Pow(a, Neg(x));
XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
XlaOp a_inverse_square = Reciprocal(Square(a));
XlaOp horner_sum = ScalarLike(a, 0.);
XlaOp factor = ScalarLike(a, 1.);
// Use Horner's rule for this.
// Note this differs from Cephes, which does a 'naive' polynomial evaluation.
// Using Horner's rule avoids some NaNs and Infs, resulting in more
// numerically stable code.
for (int i = 0; i < 11; ++i) {
factor =
(x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
horner_sum = factor * a_inverse_square *
(horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
}
s = s + neg_power *
(ScalarLike(neg_power, 0.5) +
x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
const double nan = std::numeric_limits<double>::quiet_NaN();
const double inf = std::numeric_limits<double>::infinity();
// Use the initial zeta sum without the correction term coming
// from Euler-Maclaurin if it is accurate enough.
XlaOp output =
Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
initial_sum, s);
// This is the harmonic series.
output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
// Function is not defined for x < 1.
output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
// For q <= 0: q^(-x) is undefined when x is not an integer, so the result
// is NaN; every non-positive integer q is a pole, where the result is inf
// (matching Cephes and the testBadValues case above).
XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
output = Select(domain_error, ScalarLike(x, nan), output);
output = Select(negative_integer_q, ScalarLike(x, inf), output);
return output;
};
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
if (x_shape != q_shape) {
return InvalidArgument(
"Arguments to Zeta must have equal shapes and types; got %s and %s",
x_shape.ToString(), q_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
bool needs_upcast =
x_shape.element_type() == F16 || x_shape.element_type() == BF16;
if (needs_upcast) {
x = ConvertElementType(x, F32);
q = ConvertElementType(q, F32);
}
XlaOp result = doit(x, q, x_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, x_shape.element_type());
}
return result;
});
}
} // namespace xla
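To make the scheme above easier to follow, here is a NumPy transcription of the same computation: ten explicit terms of the defining series, an integral correction, and the Horner-evaluated Euler-Maclaurin tail, with the special cases applied at the end. It is only an illustrative sketch of the algorithm as written above, not part of the change; the helper name zeta_reference is made up for this example.

import numpy as np

# (2k)! / B_{2k} (B_{2k} the Bernoulli numbers), in reverse order, as above.
_COEFFS = np.array([
    -7.1661652561756670113e18, 1.8152105401943546773e17,
    -4.5979787224074726105e15, 1.1646782814350067249e14,
    -2.950130727918164224e12, 7.47242496e10, -1.8924375803183791606e9,
    47900160.0, -1209600.0, 30240.0, -720.0, 12.0,
])


def zeta_reference(x, q):
  """Hurwitz zeta, zeta(x, q) = sum_{n>=0} (q + n)^(-x), via Euler-Maclaurin."""
  x = np.asarray(x, dtype=np.float64)
  q = np.asarray(q, dtype=np.float64)
  # Ten explicit terms of the defining series.
  a = q
  initial_sum = np.power(a, -x)
  for _ in range(9):
    a = a + 1.
    initial_sum = initial_sum + np.power(a, -x)
  a = a + 1.
  neg_power = np.power(a, -x)
  # Integral correction plus the Euler-Maclaurin tail (Horner evaluation).
  s = initial_sum + neg_power * a / (x - 1.)
  a_inv_sq = 1. / (a * a)
  horner_sum = np.zeros_like(a)
  for i in range(11):
    factor = (x - (22 - 2 * i)) * (x - (21 - 2 * i))
    horner_sum = factor * a_inv_sq * (horner_sum + 1. / _COEFFS[i])
  s = s + neg_power * (0.5 + x / a * (1. / _COEFFS[11] + horner_sum))
  # Use the plain series whenever the dropped tail is already below epsilon.
  eps = np.finfo(np.float64).eps
  out = np.where(np.abs(neg_power) < np.abs(initial_sum) * eps, initial_sum, s)
  # Special cases: harmonic series at x == 1, undefined for x < 1, and the
  # q <= 0 handling (NaN for non-integer x, +inf at the poles q = 0, -1, ...).
  out = np.where(x == 1., np.inf, out)
  out = np.where(x < 1., np.nan, out)
  out = np.where((q <= 0.) & (x != np.floor(x)), np.nan, out)
  out = np.where((q <= 0.) & (q == np.floor(q)), np.inf, out)
  return out

It can be spot-checked against scipy.special.zeta(x, q) the same way the new tests do.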


@@ -72,6 +72,9 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);
// Computes the Riemann zeta function of two arguments.
XlaOp Zeta(XlaOp x, XlaOp q);
// Rounds the given number to even when the number is equidistant between two
// integers.
XlaOp RoundToEven(XlaOp x);
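For reference, the two-argument zeta declared above is the Hurwitz zeta function,

    zeta(x, q) = sum_{n=0}^{inf} (q + n)^(-x),   x > 1,

which reduces to the Riemann zeta function at q = 1 and matches the convention used by tf.math.zeta and scipy.special.zeta in the tests above.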


@@ -291,6 +291,7 @@ void BuildOpsSubmodule(py::module* m) {
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
py::arg("b"), py::arg("x"));
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
#define BINARY_OP(op) \
ops.def( \


@@ -196,12 +196,15 @@ class MathTest(PForTestCase, parameterized.TestCase):
math_ops.subtract,
math_ops.truncate_mod,
safe_polygamma,
safe_zeta,
]
# FloorDiv fails on XLA due to floor's discontinuities exacerbating small
# division differences.
if not test_util.is_xla_enabled():
float_ops += [math_ops.floor_div]
# TODO(b/168912036): Re-enable once GPU + XLA issues for Zeta are
# resolved.
if not test_util.is_gpu_available():
float_ops += [safe_zeta]
for op in logical_ops + float_ops:
x = random_ops.random_uniform([7, 3, 5])
y = random_ops.random_uniform([3, 5])