Add XLA registration for Zeta/Zetac
PiperOrigin-RevId: 332509521 Change-Id: If93869e3e4732cb09e85ec9cc61b2fe085fdb1ce
This commit is contained in:
parent
7813d49417
commit
c8ee679b8e
tensorflow
compiler
jit
tests
tf2xla
xla
python/ops/parallel_for
@ -2094,6 +2094,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"XlaSpmdShardToFullShape",
|
||||
"XlaSvd",
|
||||
"XlaWhile",
|
||||
"Zeta",
|
||||
"_Arg",
|
||||
"_ArrayToList",
|
||||
"_ListToArray",
|
||||
|
@ -31,6 +31,7 @@ import six
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
@ -54,6 +55,11 @@ def _igammac(a, x):
|
||||
return math_ops.igammac(a, x)
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def _zeta(a, q):
|
||||
return math_ops.zeta(a, q)
|
||||
|
||||
|
||||
# This is df/da / df/dx, where f = igamma.
|
||||
def implicit_reparameterization_grad(a, x):
|
||||
log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
|
||||
@ -136,6 +142,120 @@ class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
|
||||
|
||||
|
||||
class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if flags.FLAGS.vary_seed:
|
||||
entropy = os.urandom(64)
|
||||
if six.PY2:
|
||||
answer = int(entropy.encode('hex'), 16)
|
||||
else:
|
||||
answer = int.from_bytes(entropy, 'big')
|
||||
np.random.seed(answer % (2**32 - 1))
|
||||
super(ZetaTest, self).setUp()
|
||||
|
||||
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
|
||||
if self.device not in ['TPU']:
|
||||
return rtol, atol
|
||||
|
||||
if dtype == np.float32:
|
||||
return 2e-2, 1e-7
|
||||
return 2e-4, 1e-20
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testBadValues(self):
|
||||
q = np.random.uniform(low=0.3, high=20., size=[10])
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(np.float64(1.), q)
|
||||
actual = sess.run(y)
|
||||
# When x == 1, this is the Harmonic series.
|
||||
self.assertTrue(np.all(np.isinf(actual)))
|
||||
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(np.float64(0.1), q)
|
||||
actual = sess.run(y)
|
||||
# When x < 1, this is undefined.
|
||||
self.assertTrue(np.all(np.isnan(actual)))
|
||||
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta([1., 1.1], [-1.1, -1.])
|
||||
actual = sess.run(y)
|
||||
|
||||
# When q is negative, zeta is not defined
|
||||
# if q is an integer or x is not an integer.
|
||||
self.assertTrue(np.all(np.isinf(actual)))
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testLargeXSmallQ(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
|
||||
# TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns
|
||||
# infs.
|
||||
self.skipTest(
|
||||
'Skipping test because some F64 operations are numerically '
|
||||
'unstable on TPU.')
|
||||
|
||||
x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
y = _zeta(x, q)
|
||||
actual = sess.run(y)
|
||||
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testSmallValues(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
# Test values near zero.
|
||||
x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(
|
||||
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-2, 1e-11),
|
||||
(np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testMediumValues(self, dtype, rtol, atol):
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
|
||||
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
|
||||
def testLargeValues(self, dtype, rtol, atol):
|
||||
x = np.random.uniform(
|
||||
low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype)
|
||||
q = np.random.uniform(
|
||||
low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype)
|
||||
|
||||
expected_values = sps.zeta(x, q)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = sess.run(_zeta(x, q))
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -290,6 +290,13 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
|
||||
XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
|
||||
std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
|
||||
return xla::Zeta(x, q);
|
||||
}
|
||||
|
||||
XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
#undef XLA_MAKE_BINARY
|
||||
|
||||
class ApproximateEqualOp : public XlaOpKernel {
|
||||
|
@ -206,6 +206,7 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
|
||||
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
|
||||
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
|
||||
igammac = _broadcasting_binary_op(math_ops.igammac)
|
||||
zeta = _broadcasting_binary_op(math_ops.zeta)
|
||||
|
||||
|
||||
def _binary_op(fn):
|
||||
|
@ -1832,4 +1832,98 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Zeta(XlaOp x, XlaOp q) {
|
||||
auto& builder = *x.builder();
|
||||
auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
|
||||
// (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
|
||||
// These are ordered in reverse.
|
||||
static const std::array<double, 12> kZetaCoeffs{
|
||||
-7.1661652561756670113e18,
|
||||
1.8152105401943546773e17,
|
||||
-4.5979787224074726105e15,
|
||||
1.1646782814350067249e14,
|
||||
-2.950130727918164224e12,
|
||||
7.47242496e10,
|
||||
-1.8924375803183791606e9,
|
||||
47900160.0,
|
||||
-1209600.0,
|
||||
30240.0,
|
||||
-720.0,
|
||||
12.0,
|
||||
};
|
||||
|
||||
// For speed we'll always use 9 iterations for the initial series estimate,
|
||||
// and a 12 term expansion for the Euler-Maclaurin formula.
|
||||
|
||||
XlaOp a = q;
|
||||
XlaOp neg_power = ScalarLike(a, 0.);
|
||||
XlaOp initial_sum = Pow(q, Neg(x));
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
a = a + ScalarLike(a, 1.);
|
||||
neg_power = Pow(a, Neg(x));
|
||||
initial_sum = initial_sum + neg_power;
|
||||
}
|
||||
a = a + ScalarLike(a, 1.);
|
||||
neg_power = Pow(a, Neg(x));
|
||||
XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
|
||||
XlaOp a_inverse_square = Reciprocal(Square(a));
|
||||
XlaOp horner_sum = ScalarLike(a, 0.);
|
||||
XlaOp factor = ScalarLike(a, 1.);
|
||||
// Use Horner's rule for this.
|
||||
// Note this differs from Cephes which does a 'naive' polynomial evaluation.
|
||||
// Using Horner's rule allows to avoid some NaN's and Infs from happening,
|
||||
// resulting in more numerically stable code.
|
||||
for (int i = 0; i < 11; ++i) {
|
||||
factor =
|
||||
(x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
|
||||
horner_sum = factor * a_inverse_square *
|
||||
(horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
|
||||
}
|
||||
s = s + neg_power *
|
||||
(ScalarLike(neg_power, 0.5) +
|
||||
x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
|
||||
|
||||
const double nan = std::numeric_limits<double>::quiet_NaN();
|
||||
const double inf = std::numeric_limits<double>::infinity();
|
||||
// Use the initial zeta sum without the correction term coming
|
||||
// from Euler-Maclaurin if it is accurate enough.
|
||||
XlaOp output =
|
||||
Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
|
||||
initial_sum, s);
|
||||
// This is the harmonic series.
|
||||
output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
|
||||
// Function is not defined for x < 1.
|
||||
output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
|
||||
// If q <= 0, then when q is an integer or x is not an integer, this is
|
||||
// NaN.
|
||||
XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
|
||||
XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
|
||||
output = Select(negative_integer_q, ScalarLike(x, inf), output);
|
||||
output = Select(domain_error, ScalarLike(x, nan), output);
|
||||
return output;
|
||||
};
|
||||
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
|
||||
TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
|
||||
if (x_shape != q_shape) {
|
||||
return InvalidArgument(
|
||||
"Arguments to Zeta must have equal shapes and types; got %s and %s",
|
||||
x_shape.ToString(), q_shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
|
||||
bool needs_upcast =
|
||||
x_shape.element_type() == F16 || x_shape.element_type() == BF16;
|
||||
|
||||
if (needs_upcast) {
|
||||
x = ConvertElementType(x, F32);
|
||||
q = ConvertElementType(q, F32);
|
||||
}
|
||||
XlaOp result = doit(x, q, x_shape.element_type());
|
||||
if (needs_upcast) {
|
||||
result = ConvertElementType(result, x_shape.element_type());
|
||||
}
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -72,6 +72,9 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
|
||||
// Computes an approximation of the complementary incomplete gamma function.
|
||||
XlaOp Igammac(XlaOp a, XlaOp x);
|
||||
|
||||
// Computes the Riemann zeta function of two arguments.
|
||||
XlaOp Zeta(XlaOp x, XlaOp q);
|
||||
|
||||
// Rounds the given number to even when the number is equidistant between two
|
||||
// integers.
|
||||
XlaOp RoundToEven(XlaOp x);
|
||||
|
@ -291,6 +291,7 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
|
||||
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
|
||||
py::arg("b"), py::arg("x"));
|
||||
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
|
||||
|
||||
#define BINARY_OP(op) \
|
||||
ops.def( \
|
||||
|
@ -196,12 +196,15 @@ class MathTest(PForTestCase, parameterized.TestCase):
|
||||
math_ops.subtract,
|
||||
math_ops.truncate_mod,
|
||||
safe_polygamma,
|
||||
safe_zeta,
|
||||
]
|
||||
# FloorDiv fails on XLA due floor's discontinuities exacerbating small
|
||||
# division differences.
|
||||
if not test_util.is_xla_enabled():
|
||||
float_ops += [math_ops.floor_div]
|
||||
# TODO(b/168912036): Re-enable once GPU + XLA issues for Zeta are
|
||||
# resolved.
|
||||
if not test_util.is_gpu_available():
|
||||
float_ops += [safe_zeta]
|
||||
for op in logical_ops + float_ops:
|
||||
x = random_ops.random_uniform([7, 3, 5])
|
||||
y = random_ops.random_uniform([3, 5])
|
||||
|
Loading…
Reference in New Issue
Block a user