From e60c1ba960e598be9c0e0cdd331cdc10e8919dbb Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Mon, 8 Jun 2020 11:48:32 -0700 Subject: [PATCH] [XLA] Improve numerical stability of Logistic. PiperOrigin-RevId: 315320526 Change-Id: Iedfd22d0fb657cb31dda537786ce001f1dab168b --- tensorflow/compiler/tests/unary_ops_test.py | 10 ++++++++++ tensorflow/compiler/xla/client/lib/math.cc | 4 ++-- tensorflow/python/kernel_tests/rnn_cell_test.py | 7 ++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 85bf89c4f9e..f2ec6be43cb 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -510,6 +510,16 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler") + def testFloatOpsDisabledOnMlirBridge(self): + for dtype in self.float_types: + if dtype != np.float16: + self._assertOpOutputMatchesExpected( + lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)), + np.array([-40, 40], dtype=dtype), + expected=np.array([1.0, 0.025], dtype=dtype)) + @test_util.disable_mlir_bridge( "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") def testQuantizeAndDequantize(self): diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index f2ee94a0159..6cbaa043055 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -1394,8 +1394,8 @@ XlaOp NextAfter(XlaOp from, XlaOp to) { } XlaOp Logistic(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - return half + half * xla::Tanh(half * x); + auto one = xla::ScalarLike(x, 1); + return xla::Div(one, (one + xla::Exp(xla::Neg(x)))); } // Computes an approximation to the modified Bessel function of the first kind, diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index d29c533badf..9de14006de2 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -35,7 +35,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1016,8 +1015,7 @@ class LSTMTest(test.TestCase): }) comparison_fn = self.assertAllEqual - if (test_util.is_xla_enabled() and - control_flow_v2_toggles.control_flow_v2_enabled()): + if test_util.is_xla_enabled(): comparison_fn = self.assertAllClose if in_graph_mode: comparison_fn(outputs_static, outputs_dynamic) @@ -1107,8 +1105,7 @@ class LSTMTest(test.TestCase): }) comparison_fn = self.assertAllEqual - if (test_util.is_xla_enabled() and - control_flow_v2_toggles.control_flow_v2_enabled()): + if test_util.is_xla_enabled(): comparison_fn = self.assertAllClose if in_graph_mode: comparison_fn(outputs_static, outputs_dynamic)