Add XLA registration for Polygamma, to be used for Digamma gradients.

PiperOrigin-RevId: 332557647
Change-Id: I97fba661240412a49716544c5cf106a700ab89a4
This commit is contained in:
Srinivas Vasudevan 2020-09-18 17:24:02 -07:00 committed by TensorFlower Gardener
parent 27b417360c
commit b946521465
6 changed files with 147 additions and 0 deletions

View File

@ -1951,6 +1951,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ParallelDynamicStitch",
"ParameterizedTruncatedNormal",
"PartitionedCall",
"Polygamma",
"PopulationCount",
"Qr",
"QuantizeAndDequantizeV2",

View File

@ -55,6 +55,11 @@ def _igammac(a, x):
return math_ops.igammac(a, x)
@def_function.function(experimental_compile=True)
def _polygamma(n, x):
return math_ops.polygamma(n, x)
@def_function.function(experimental_compile=True)
def _zeta(a, q):
return math_ops.zeta(a, q)
@ -256,6 +261,94 @@ class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer % (2**32 - 1))
super(PolygammaTest, self).setUp()
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 2e-2, 1e-7
return 2e-4, 1e-20
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testBadValues(self):
x = np.random.uniform(low=0.3, high=20., size=[10])
with self.session() as sess:
with self.test_scope():
y = _polygamma(np.float64(-1.), x)
actual = sess.run(y)
# Not defined for negative numbers.
self.assertTrue(np.all(np.isnan(actual)))
with self.session() as sess:
with self.test_scope():
y = _polygamma(np.float64(0.1), x)
actual = sess.run(y)
# Not defined for non-integers.
self.assertTrue(np.all(np.isnan(actual)))
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testRecoverDigamma(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
self.skipTest(
'Skipping test because some F64 operations are '
'numerically unstable on TPU.'
)
x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.digamma(x)
with self.session() as sess:
with self.test_scope():
y = _polygamma(dtype(0.), x)
actual = sess.run(y)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testSmallN(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
# Test values near zero.
n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype)
x = np.random.uniform(
low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.polygamma(n, x)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_polygamma(n, x))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-2, 1e-11),
(np.float64, 1e-4, 1e-30))
@test_util.disable_mlir_bridge('TODO(b/165736950): Add support in MLIR')
def testMediumLargeN(self, dtype, rtol, atol):
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype)
x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)
expected_values = sps.polygamma(n, x)
with self.session() as sess:
with self.test_scope():
actual = sess.run(_polygamma(n, x))
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):

View File

@ -290,6 +290,14 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y,
XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper));
xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x,
const BCast& broadcast_helper) {
std::tie(n, x) = XlaBinaryOp::Broadcast(n, x, broadcast_helper);
return xla::Polygamma(n, x);
}
XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper));
xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) {
std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper);
return xla::Zeta(x, q);

View File

@ -206,6 +206,7 @@ igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)

View File

@ -1832,6 +1832,47 @@ XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
});
}
XlaOp Polygamma(XlaOp n, XlaOp x) {
auto& builder = *x.builder();
auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
XlaOp n_plus_one = n + ScalarLike(n, 1.);
XlaOp sign =
(ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
const double nan = std::numeric_limits<double>::quiet_NaN();
XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
// Check that n is a natural number.
output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
ScalarLike(n, nan), output);
return output;
};
return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
if (n_shape != x_shape) {
return InvalidArgument(
"Arguments to Polygamma must have equal shapes and types; "
"got %s and %s",
n_shape.ToString(), x_shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
bool needs_upcast =
n_shape.element_type() == F16 || x_shape.element_type() == BF16;
if (needs_upcast) {
n = ConvertElementType(n, F32);
x = ConvertElementType(x, F32);
}
XlaOp result = doit(n, x, n_shape.element_type());
if (needs_upcast) {
result = ConvertElementType(result, n_shape.element_type());
}
return result;
});
}
XlaOp Zeta(XlaOp x, XlaOp q) {
auto& builder = *x.builder();
auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {

View File

@ -72,6 +72,9 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x);
// Computes an approximation of the complementary incomplete gamma function.
XlaOp Igammac(XlaOp a, XlaOp x);
// Computes the Polygamma of two arguments.
XlaOp Polygamma(XlaOp n, XlaOp x);
// Computes the Riemann zeta function of two arguments.
XlaOp Zeta(XlaOp x, XlaOp q);