[XLA] Improve numerical stability of Logistic.
PiperOrigin-RevId: 315320526 Change-Id: Iedfd22d0fb657cb31dda537786ce001f1dab168b
This commit is contained in:
parent
699af178b3
commit
e60c1ba960
@ -510,6 +510,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
|||||||
],
|
],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/155501444): Handle _UnaryOpsComposition ops from Grappler")
|
||||||
|
def testFloatOpsDisabledOnMlirBridge(self):
|
||||||
|
for dtype in self.float_types:
|
||||||
|
if dtype != np.float16:
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)),
|
||||||
|
np.array([-40, 40], dtype=dtype),
|
||||||
|
expected=np.array([1.0, 0.025], dtype=dtype))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge(
|
@test_util.disable_mlir_bridge(
|
||||||
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
|
"TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation")
|
||||||
def testQuantizeAndDequantize(self):
|
def testQuantizeAndDequantize(self):
|
||||||
|
@ -1394,8 +1394,8 @@ XlaOp NextAfter(XlaOp from, XlaOp to) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Logistic(XlaOp x) {
|
XlaOp Logistic(XlaOp x) {
|
||||||
auto half = xla::ScalarLike(x, 0.5);
|
auto one = xla::ScalarLike(x, 1);
|
||||||
return half + half * xla::Tanh(half * x);
|
return xla::Div(one, (one + xla::Exp(xla::Neg(x))));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Computes an approximation to the modified Bessel function of the first kind,
|
// Computes an approximation to the modified Bessel function of the first kind,
|
||||||
|
@ -35,7 +35,6 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import control_flow_v2_toggles
|
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -1016,8 +1015,7 @@ class LSTMTest(test.TestCase):
|
|||||||
})
|
})
|
||||||
|
|
||||||
comparison_fn = self.assertAllEqual
|
comparison_fn = self.assertAllEqual
|
||||||
if (test_util.is_xla_enabled() and
|
if test_util.is_xla_enabled():
|
||||||
control_flow_v2_toggles.control_flow_v2_enabled()):
|
|
||||||
comparison_fn = self.assertAllClose
|
comparison_fn = self.assertAllClose
|
||||||
if in_graph_mode:
|
if in_graph_mode:
|
||||||
comparison_fn(outputs_static, outputs_dynamic)
|
comparison_fn(outputs_static, outputs_dynamic)
|
||||||
@ -1107,8 +1105,7 @@ class LSTMTest(test.TestCase):
|
|||||||
})
|
})
|
||||||
|
|
||||||
comparison_fn = self.assertAllEqual
|
comparison_fn = self.assertAllEqual
|
||||||
if (test_util.is_xla_enabled() and
|
if test_util.is_xla_enabled():
|
||||||
control_flow_v2_toggles.control_flow_v2_enabled()):
|
|
||||||
comparison_fn = self.assertAllClose
|
comparison_fn = self.assertAllClose
|
||||||
if in_graph_mode:
|
if in_graph_mode:
|
||||||
comparison_fn(outputs_static, outputs_dynamic)
|
comparison_fn(outputs_static, outputs_dynamic)
|
||||||
|
Loading…
Reference in New Issue
Block a user