[XLA] Improve numerical stability of Logistic.

PiperOrigin-RevId: 314371904
Change-Id: Ica3e5b3146cbce739e76bc667a3df0571f671fb4
This commit is contained in:
Thomas Joerg 2020-06-02 11:37:42 -07:00 committed by TensorFlower Gardener
parent 0c44453d02
commit df6b2f04a1
3 changed files with 18 additions and 7 deletions

View File

@ -510,6 +510,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
], ],
dtype=dtype)) dtype=dtype))
@test_util.disable_mlir_bridge(
"TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler")
def testFloatOpsDisabledOnMlirBridge(self):
for dtype in self.float_types:
if dtype != np.float16:
self._assertOpOutputMatchesExpected(
lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)),
np.array([-40, 40], dtype=dtype),
expected=np.array([1.0, 0.025], dtype=dtype))
@test_util.disable_mlir_bridge( @test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
def testQuantizeAndDequantize(self): def testQuantizeAndDequantize(self):

View File

@ -1394,8 +1394,12 @@ XlaOp NextAfter(XlaOp from, XlaOp to) {
} }
XlaOp Logistic(XlaOp x) { XlaOp Logistic(XlaOp x) {
if (x.builder()->GetShape(x).ValueOrDie().element_type() == F16) {
auto half = xla::ScalarLike(x, 0.5); auto half = xla::ScalarLike(x, 0.5);
return half + half * xla::Tanh(half * x); return half + half * xla::Tanh(half * x);
}
auto one = xla::ScalarLike(x, 1);
return xla::Div(one, (one + xla::Exp(xla::Neg(x))));
} }
// Computes an approximation to the modified Bessel function of the first kind, // Computes an approximation to the modified Bessel function of the first kind,

View File

@ -35,7 +35,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -1016,8 +1015,7 @@ class LSTMTest(test.TestCase):
}) })
comparison_fn = self.assertAllEqual comparison_fn = self.assertAllEqual
if (test_util.is_xla_enabled() and if test_util.is_xla_enabled():
control_flow_v2_toggles.control_flow_v2_enabled()):
comparison_fn = self.assertAllClose comparison_fn = self.assertAllClose
if in_graph_mode: if in_graph_mode:
comparison_fn(outputs_static, outputs_dynamic) comparison_fn(outputs_static, outputs_dynamic)
@ -1107,8 +1105,7 @@ class LSTMTest(test.TestCase):
}) })
comparison_fn = self.assertAllEqual comparison_fn = self.assertAllEqual
if (test_util.is_xla_enabled() and if test_util.is_xla_enabled():
control_flow_v2_toggles.control_flow_v2_enabled()):
comparison_fn = self.assertAllClose comparison_fn = self.assertAllClose
if in_graph_mode: if in_graph_mode:
comparison_fn(outputs_static, outputs_dynamic) comparison_fn(outputs_static, outputs_dynamic)