From 4bddb55693ea2d96b000369f84ab693304320541 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 2 Jun 2020 13:08:32 -0700 Subject: [PATCH] [XLA] Improve numerical stability of Logistic. PiperOrigin-RevId: 314389569 Change-Id: Ia5b3f01f6c3f56f4c876b406900147dfe32ce0df --- tensorflow/compiler/tests/unary_ops_test.py | 10 ---------- tensorflow/compiler/xla/client/lib/math.cc | 8 ++------ tensorflow/python/kernel_tests/rnn_cell_test.py | 7 +++++-- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index f2ec6be43cb..85bf89c4f9e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -510,16 +510,6 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler") - def testFloatOpsDisabledOnMlirBridge(self): - for dtype in self.float_types: - if dtype != np.float16: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)), - np.array([-40, 40], dtype=dtype), - expected=np.array([1.0, 0.025], dtype=dtype)) - @test_util.disable_mlir_bridge( "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") def testQuantizeAndDequantize(self): diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index edc86e546cd..f2ee94a0159 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -1394,12 +1394,8 @@ XlaOp NextAfter(XlaOp from, XlaOp to) { } XlaOp Logistic(XlaOp x) { - if (x.builder()->GetShape(x).ValueOrDie().element_type() == F16) { - auto half = xla::ScalarLike(x, 0.5); - return half + half * xla::Tanh(half * x); - } - auto one = xla::ScalarLike(x, 1); - return xla::Div(one, (one + xla::Exp(xla::Neg(x)))); + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); } // Computes an approximation to the modified Bessel function of the first kind, diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 9de14006de2..d29c533badf 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1015,7 +1016,8 @@ class LSTMTest(test.TestCase): }) comparison_fn = self.assertAllEqual - if test_util.is_xla_enabled(): + if (test_util.is_xla_enabled() and + control_flow_v2_toggles.control_flow_v2_enabled()): comparison_fn = self.assertAllClose if in_graph_mode: comparison_fn(outputs_static, outputs_dynamic) @@ -1105,7 +1107,8 @@ class LSTMTest(test.TestCase): }) comparison_fn = self.assertAllEqual - if test_util.is_xla_enabled(): + if (test_util.is_xla_enabled() and + control_flow_v2_toggles.control_flow_v2_enabled()): comparison_fn = self.assertAllClose if in_graph_mode: comparison_fn(outputs_static, outputs_dynamic)