[XLA] Softplus should be monotonic
While we are here, add NumPy support for sorting bfloat16 values to make it easier to write the test. PiperOrigin-RevId: 315363425 Change-Id: I3830835549ca02754da8c657e3722f9f0462a12a
This commit is contained in:
parent
7152155517
commit
27d684112b
@ -85,6 +85,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
for i in xrange(len(result)):
|
||||
self.assertAllClose(result[i], expected[i], rtol, atol)
|
||||
|
||||
def AssertCloseAndSorted(self, result, expected, rtol, atol):
|
||||
"""Tests that result and expected are both close and sorted."""
|
||||
self.assertAllClose(result, expected, rtol, atol)
|
||||
self.assertAllEqual(np.sort(result), result)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"MlirHloBuilder::Iota missing required for xla::Diag")
|
||||
def testAllTypeOps(self):
|
||||
@ -1122,17 +1127,27 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
|
||||
dtype=dtype))
|
||||
|
||||
def _assertSoftplusMatchesExpected(self, features, dtype):
|
||||
def _assertSoftplusMatchesExpected(self,
|
||||
features,
|
||||
dtype,
|
||||
equality_test=None,
|
||||
rtol=1e-6,
|
||||
atol=9.1e-6):
|
||||
features = np.array(features, dtype=dtype)
|
||||
zero = np.asarray(0).astype(dtype)
|
||||
expected = np.logaddexp(zero, features).astype(dtype)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
|
||||
nn_ops.softplus,
|
||||
features,
|
||||
expected=expected,
|
||||
equality_test=equality_test,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
|
||||
def testSoftplus(self):
|
||||
for dtype in self.float_types:
|
||||
for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
|
||||
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
|
||||
self._assertSoftplusMatchesExpected(
|
||||
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
|
||||
@ -1148,6 +1163,13 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
-log_eps + ten
|
||||
], dtype)
|
||||
|
||||
self._assertSoftplusMatchesExpected(
|
||||
[0.69302183, 0.69324386],
|
||||
dtype,
|
||||
equality_test=self.AssertCloseAndSorted,
|
||||
rtol=9e-5,
|
||||
atol=9e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -89,16 +89,25 @@ XLAJIT_MAKE_UNARY(Sign,
|
||||
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
|
||||
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
|
||||
|
||||
// softplus(x) = log(1 + exp(x))
|
||||
//
|
||||
// This is not numerically stable when x is large, it can easily overflow.
|
||||
// However, we can compute it as LogSumExp(x, 0):
|
||||
// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))
|
||||
//
|
||||
// This is equivalent to:
|
||||
// max(x, 0) + log1p(exp(-abs(x)))
|
||||
XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
|
||||
xla::Log1p(xla::Exp(-xla::Abs(x))));
|
||||
// Builds the XLA computation for softplus(features) = log(1 + exp(features)).
//
// The result is selected piecewise, element by element:
//   * features > -threshold  -> features
//   * features <  threshold  -> exp(features)
//   * otherwise              -> log1p(exp(features))
// where threshold = log(epsilon of the element type) + 2.
//
// Any error from looking up the shape of `features` is reported through
// b->ReportErrorOrReturn.
static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) {
|
||||
return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(features));
|
||||
// The cut-over point is derived from the epsilon of the element type, so it
// adapts to the precision of `features`.
xla::XlaOp threshold =
|
||||
Log(xla::Epsilon(b, shape.element_type())) + ScalarLike(features, 2.0);
|
||||
// Value above which exp(x) may overflow, but softplus(x) == x
|
||||
// is within machine epsilon.
|
||||
xla::XlaOp too_large = Gt(features, -threshold);
|
||||
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
|
||||
// is within machine epsilon.
|
||||
xla::XlaOp too_small = Lt(features, threshold);
|
||||
xla::XlaOp features_exp = Exp(features);
|
||||
// The outer Select takes precedence: too_large is tested before too_small.
xla::XlaOp output =
|
||||
Select(too_large, features,
|
||||
Select(too_small, features_exp, Log1p(features_exp)));
|
||||
return output;
|
||||
});
|
||||
}
|
||||
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, x));
|
||||
|
||||
// softsign(x) = x / (abs(x) + 1)
|
||||
XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));
|
||||
|
@ -441,6 +441,29 @@ void ByteSwap16(void* value) {
|
||||
std::swap(p[0], p[1]);
|
||||
}
|
||||
|
||||
// Three-way comparison of the two bfloat16 values stored at `a` and `b`.
//
// Returns -1 if the value at `a` orders before the value at `b`, 1 if it
// orders after, and 0 otherwise (equal values, or both NaN). A NaN orders
// after every non-NaN value. `arr` is not used.
int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
|
||||
// Copy the operands out with memcpy instead of dereferencing the pointers
// (presumably because the storage may not be suitably aligned -- confirm).
bfloat16 x;
|
||||
memcpy(&x, a, sizeof(bfloat16));
|
||||

|
||||
bfloat16 y;
|
||||
memcpy(&y, b, sizeof(bfloat16));
|
||||

|
||||
if (x < y) {
|
||||
return -1;
|
||||
}
|
||||
if (y < x) {
|
||||
return 1;
|
||||
}
|
||||
// NaNs sort to the end.
|
||||
if (!std::isnan(x) && std::isnan(y)) {
|
||||
return -1;
|
||||
}
|
||||
if (std::isnan(x) && !std::isnan(y)) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
|
||||
npy_intp sstride, npy_intp n, int swap, void* arr) {
|
||||
char* dst = reinterpret_cast<char*>(dstv);
|
||||
@ -1280,6 +1303,7 @@ bool Initialize() {
|
||||
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
|
||||
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
|
||||
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
|
||||
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
|
||||
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
|
||||
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
|
||||
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
|
||||
|
@ -219,6 +219,12 @@ class Bfloat16Test(parameterized.TestCase):
|
||||
numpy_assert_allclose(
|
||||
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(FLOAT_VALUES)
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
|
||||
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
|
||||
|
||||
|
@ -412,6 +412,29 @@ void ByteSwap16(void* value) {
|
||||
std::swap(p[0], p[1]);
|
||||
}
|
||||
|
||||
// Three-way comparison of the two bfloat16 values stored at `a` and `b`.
//
// Returns -1 if the value at `a` orders before the value at `b`, 1 if it
// orders after, and 0 otherwise (equal values, or both NaN). A NaN orders
// after every non-NaN value. `arr` is not used.
int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
|
||||
// Copy the operands out with memcpy instead of dereferencing the pointers
// (presumably because the storage may not be suitably aligned -- confirm).
bfloat16 x;
|
||||
memcpy(&x, a, sizeof(bfloat16));
|
||||

|
||||
bfloat16 y;
|
||||
memcpy(&y, b, sizeof(bfloat16));
|
||||

|
||||
if (x < y) {
|
||||
return -1;
|
||||
}
|
||||
if (y < x) {
|
||||
return 1;
|
||||
}
|
||||
// NaNs sort to the end.
|
||||
if (!std::isnan(x) && std::isnan(y)) {
|
||||
return -1;
|
||||
}
|
||||
if (std::isnan(x) && !std::isnan(y)) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
|
||||
npy_intp sstride, npy_intp n, int swap, void* arr) {
|
||||
char* dst = reinterpret_cast<char*>(dstv);
|
||||
@ -561,6 +584,7 @@ bool Initialize() {
|
||||
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
|
||||
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
|
||||
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
|
||||
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
|
||||
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
|
||||
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
|
||||
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
|
||||
|
@ -32,15 +32,19 @@ from tensorflow.python.platform import test
|
||||
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||
|
||||
|
||||
class Bfloat16Test(test.TestCase):
|
||||
def float_values():
|
||||
"""Returns values that should round trip exactly to float and back."""
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
return [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"),
|
||||
float("-inf"),
|
||||
float("nan")
|
||||
]
|
||||
|
||||
def float_values(self):
|
||||
"""Returns values that should round trip exactly to float and back."""
|
||||
epsilon = float.fromhex("1.0p-7")
|
||||
return [
|
||||
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
|
||||
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
|
||||
float("inf"), float("-inf"), float("nan")]
|
||||
|
||||
class Bfloat16Test(test.TestCase):
|
||||
|
||||
def _assertFloatIdentical(self, v, w):
|
||||
if math.isnan(v):
|
||||
@ -49,7 +53,7 @@ class Bfloat16Test(test.TestCase):
|
||||
self.assertEqual(v, w)
|
||||
|
||||
def testRoundTripToFloat(self):
|
||||
for v in self.float_values():
|
||||
for v in float_values():
|
||||
self._assertFloatIdentical(v, float(bfloat16(v)))
|
||||
|
||||
def testRoundTripToInt(self):
|
||||
@ -82,7 +86,7 @@ class Bfloat16Test(test.TestCase):
|
||||
|
||||
# Tests for Python operations
|
||||
def testNegate(self):
|
||||
for v in self.float_values():
|
||||
for v in float_values():
|
||||
self._assertFloatIdentical(-v, float(-bfloat16(v)))
|
||||
|
||||
def testAdd(self):
|
||||
@ -132,33 +136,33 @@ class Bfloat16Test(test.TestCase):
|
||||
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
|
||||
|
||||
def testLess(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
|
||||
|
||||
def testLessEqual(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
|
||||
|
||||
def testGreater(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
|
||||
|
||||
def testGreaterEqual(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
|
||||
|
||||
def testEqual(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
|
||||
|
||||
def testNotEqual(self):
|
||||
for v in self.float_values():
|
||||
for w in self.float_values():
|
||||
for v in float_values():
|
||||
for w in float_values():
|
||||
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
|
||||
|
||||
def testNan(self):
|
||||
@ -259,6 +263,12 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
|
||||
np.arange(-16384., 16384., 64., dtype=bfloat16))
|
||||
|
||||
def testSort(self):
|
||||
values_to_sort = np.float32(float_values())
|
||||
sorted_f32 = np.sort(values_to_sort)
|
||||
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
|
||||
self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user