diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 265f879252e..ff1196b2be6 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1951,6 +1951,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "ParallelDynamicStitch",
       "ParameterizedTruncatedNormal",
       "PartitionedCall",
+      "Polygamma",
       "PopulationCount",
       "Qr",
       "QuantizeAndDequantizeV2",
diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py
index 73ff3cb57da..5e7f8763743 100644
--- a/tensorflow/compiler/tests/special_math_test.py
+++ b/tensorflow/compiler/tests/special_math_test.py
@@ -55,6 +55,11 @@ def _igammac(a, x):
   return math_ops.igammac(a, x)
 
 
+@def_function.function(experimental_compile=True)
+def _polygamma(n, x):
+  return math_ops.polygamma(n, x)
+
+
 @def_function.function(experimental_compile=True)
 def _zeta(a, q):
   return math_ops.zeta(a, q)
@@ -256,6 +261,94 @@ class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
     self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
 
 
+class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase):
+
+  def setUp(self):
+    if flags.FLAGS.vary_seed:
+      entropy = os.urandom(64)
+      if six.PY2:
+        answer = int(entropy.encode('hex'), 16)
+      else:
+        answer = int.from_bytes(entropy, 'big')
+      np.random.seed(answer % (2**32 - 1))
+    super(PolygammaTest, self).setUp()
+
+  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
+    if self.device not in ['TPU']:
+      return rtol, atol
+
+    if dtype == np.float32:
+      return 2e-2, 1e-7
+    return 2e-4, 1e-20
+
+  @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
+  def testBadValues(self):
+    x = np.random.uniform(low=0.3, high=20., size=[10])
+    with self.session() as sess:
+      with self.test_scope():
+        y = _polygamma(np.float64(-1.), x)
+      actual = sess.run(y)
+    # Not defined for negative numbers.
+    self.assertTrue(np.all(np.isnan(actual)))
+
+    with self.session() as sess:
+      with self.test_scope():
+        y = _polygamma(np.float64(0.1), x)
+      actual = sess.run(y)
+    # Not defined for non-integers.
+    self.assertTrue(np.all(np.isnan(actual)))
+
+  @parameterized.parameters((np.float32, 1e-2, 1e-11),
+                            (np.float64, 1e-4, 1e-30))
+  @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
+  def testRecoverDigamma(self, dtype, rtol, atol):
+    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
+    if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
+      self.skipTest(
+          'Skipping test because some F64 operations are '
+          'numerically unstable on TPU.'
+      )
+
+    x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype)
+    expected_values = sps.digamma(x)
+    with self.session() as sess:
+      with self.test_scope():
+        y = _polygamma(dtype(0.), x)
+      actual = sess.run(y)
+
+    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
+
+  @parameterized.parameters((np.float32, 1e-2, 1e-11),
+                            (np.float64, 1e-4, 1e-30))
+  @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
+  def testSmallN(self, dtype, rtol, atol):
+    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
+    # Test values near zero.
+    n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype)
+    x = np.random.uniform(
+        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
+
+    expected_values = sps.polygamma(n, x)
+    with self.session() as sess:
+      with self.test_scope():
+        actual = sess.run(_polygamma(n, x))
+    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
+
+  @parameterized.parameters((np.float32, 1e-2, 1e-11),
+                            (np.float64, 1e-4, 1e-30))
+  @test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
+  def testMediumLargeN(self, dtype, rtol, atol):
+    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
+    n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype)
+    x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
+
+    expected_values = sps.polygamma(n, x)
+    with self.session() as sess:
+      with self.test_scope():
+        actual = sess.run(_polygamma(n, x))
+    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
+
+
 class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def setUp(self):
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index c7e421a2f1e..39f4beed0f4 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -290,6 +290,14 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
 
 XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
 
+xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
+                         const BCast& broadcast_helper) {
+  std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
+  return xla::Polygamma(n, x);
+}
+
+XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
+
 xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
   std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
   return xla::Zeta(x, q);
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 3cbd80e26ec..6278ea1a3a9 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -206,6 +206,7 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
 igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
 random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
 igammac = _broadcasting_binary_op(math_ops.igammac)
+polygamma = _broadcasting_binary_op(math_ops.polygamma)
 zeta = _broadcasting_binary_op(math_ops.zeta)
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 0d7512a900b..410c86732d6 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -1832,6 +1832,47 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
   });
 }
 
+XlaOp Polygamma(XlaOp n, XlaOp x) {
+  auto& builder = *x.builder();
+  auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
+    XlaOp n_plus_one = n + ScalarLike(n, 1.);
+    XlaOp sign =
+        (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
+
+    const double nan = std::numeric_limits<double>::quiet_NaN();
+
+    XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
+                          sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
+    // Check that n is a natural number.
+    output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
+                    ScalarLike(n, nan), output);
+    return output;
+  };
+  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
+    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
+    if (n_shape != x_shape) {
+      return InvalidArgument(
+          "Arguments to Polygamma must have equal shapes and types; "
+          "got %s and %s",
+          n_shape.ToString(), x_shape.ToString());
+    }
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Polygamma", x));
+    bool needs_upcast =
+        n_shape.element_type() == F16 || x_shape.element_type() == BF16;
+
+    if (needs_upcast) {
+      n = ConvertElementType(n, F32);
+      x = ConvertElementType(x, F32);
+    }
+    XlaOp result = doit(n, x, n_shape.element_type());
+    if (needs_upcast) {
+      result = ConvertElementType(result, n_shape.element_type());
+    }
+    return result;
+  });
+}
+
 XlaOp Zeta(XlaOp x, XlaOp q) {
   auto& builder = *x.builder();
   auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index 1f4b2143320..e6b5ac992cc 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -72,6 +72,9 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
 // Computes an approximation of the complementary incomplete gamma function.
 XlaOp Igammac(XlaOp a, XlaOp x);
 
+// Computes the Polygamma of two arguments.
+XlaOp Polygamma(XlaOp n, XlaOp x);
+
 // Computes the Riemann zeta function of two arguments.
 XlaOp Zeta(XlaOp x, XlaOp q);