[XLA] Softplus should be monotonic

While we are here, add NumPy support for sorting bfloat16 values to make it
easier to write the test.

PiperOrigin-RevId: 315363425
Change-Id: I3830835549ca02754da8c657e3722f9f0462a12a
This commit is contained in:
David Majnemer 2020-06-08 15:12:18 -07:00 committed by TensorFlower Gardener
parent 7152155517
commit 27d684112b
6 changed files with 130 additions and 35 deletions

View File

@ -85,6 +85,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
for i in xrange(len(result)):
self.assertAllClose(result[i], expected[i], rtol, atol)
def AssertCloseAndSorted(self, result, expected, rtol, atol):
"""Tests that result is close to expected and that result is sorted.

Only `result` is checked for ordering; `expected` is used solely as the
reference for the closeness comparison.

Args:
result: Array to validate.
expected: Reference array that `result` is compared against.
rtol: Relative tolerance forwarded to assertAllClose.
atol: Absolute tolerance forwarded to assertAllClose.
"""
self.assertAllClose(result, expected, rtol, atol)
# np.sort returns an ascending copy, so this equality holds only when
# `result` is already in ascending order.
self.assertAllEqual(np.sort(result), result)
@test_util.disable_mlir_bridge(
"MlirHloBuilder::Iota missing required for xla::Diag")
def testAllTypeOps(self):
@ -1122,17 +1127,27 @@ class UnaryOpsTest(xla_test.XLATestCase):
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
dtype=dtype))
def _assertSoftplusMatchesExpected(self, features, dtype):
def _assertSoftplusMatchesExpected(self,
features,
dtype,
equality_test=None,
rtol=1e-6,
atol=9.1e-6):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features).astype(dtype)
self._assertOpOutputMatchesExpected(
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
nn_ops.softplus,
features,
expected=expected,
equality_test=equality_test,
rtol=rtol,
atol=atol)
@test_util.disable_mlir_bridge(
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
def testSoftplus(self):
for dtype in self.float_types:
for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
self._assertSoftplusMatchesExpected(
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
@ -1148,6 +1163,13 @@ class UnaryOpsTest(xla_test.XLATestCase):
-log_eps + ten
], dtype)
self._assertSoftplusMatchesExpected(
[0.69302183, 0.69324386],
dtype,
equality_test=self.AssertCloseAndSorted,
rtol=9e-5,
atol=9e-5)
if __name__ == "__main__":
googletest.main()

View File

@ -89,16 +89,25 @@ XLAJIT_MAKE_UNARY(Sign,
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
// softplus(x) = log(1 + exp(x))
//
// This is not numerically stable when x is large, it can easily overflow.
// However, we can compute it as LogSumExp(x, 0):
// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
xla::Log1p(xla::Exp(-xla::Abs(x))));
// Builds softplus(x) = log(1 + exp(x)) for `features` on builder `b`.
//
// Per element, one of three expressions is selected using
// threshold = log(epsilon of the element type) + 2:
//   features > -threshold : features
//   features <  threshold : exp(features)
//   otherwise             : log1p(exp(features))
// A failure to obtain the shape of `features` is returned from the lambda
// and surfaced through XlaBuilder::ReportErrorOrReturn.
static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) {
return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(features));
// The threshold depends on the element type through its machine epsilon.
xla::XlaOp threshold =
Log(xla::Epsilon(b, shape.element_type())) + ScalarLike(features, 2.0);
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
xla::XlaOp too_large = Gt(features, -threshold);
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
xla::XlaOp too_small = Lt(features, threshold);
xla::XlaOp features_exp = Exp(features);
// The outer select takes precedence: too_large is tested before too_small.
xla::XlaOp output =
Select(too_large, features,
Select(too_small, features_exp, Log1p(features_exp)));
return output;
});
}
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, x));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));

View File

@ -441,6 +441,29 @@ void ByteSwap16(void* value) {
std::swap(p[0], p[1]);
}
// Three-way comparison of two bfloat16 values, installed as the `compare`
// entry of NPyBfloat16_ArrFuncs.
//
// `a` and `b` point to the bfloat16 operands; `arr` is not used.
// Returns -1 if *a orders before *b, 1 if it orders after, and 0 otherwise.
// A NaN orders after every non-NaN value; two NaNs compare as 0.
int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
// Copy the operands into locals instead of dereferencing the raw pointers
// (presumably because the element pointers may be unaligned).
bfloat16 x;
memcpy(&x, a, sizeof(bfloat16));
bfloat16 y;
memcpy(&y, b, sizeof(bfloat16));
if (x < y) {
return -1;
}
if (y < x) {
return 1;
}
// Reached when neither x < y nor y < x holds: the values are equal, or
// (assuming IEEE-style comparisons) at least one of them is NaN.
// NaNs sort to the end.
if (!std::isnan(x) && std::isnan(y)) {
return -1;
}
if (std::isnan(x) && !std::isnan(y)) {
return 1;
}
// Equal values, or both operands are NaN.
return 0;
}
void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
npy_intp sstride, npy_intp n, int swap, void* arr) {
char* dst = reinterpret_cast<char*>(dstv);
@ -1280,6 +1303,7 @@ bool Initialize() {
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;

View File

@ -219,6 +219,12 @@ class Bfloat16Test(parameterized.TestCase):
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
# Sorting the bfloat16 copy of FLOAT_VALUES must yield the same sequence as
# sorting the float32 originals.
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
# The bfloat16 result is converted back to float32 before comparing.
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
BinaryOp = collections.namedtuple("BinaryOp", ["op"])

View File

@ -412,6 +412,29 @@ void ByteSwap16(void* value) {
std::swap(p[0], p[1]);
}
// Three-way comparison of two bfloat16 values, installed as the `compare`
// entry of NPyBfloat16_ArrFuncs.
//
// `a` and `b` point to the bfloat16 operands; `arr` is not used.
// Returns -1 if *a orders before *b, 1 if it orders after, and 0 otherwise.
// A NaN orders after every non-NaN value; two NaNs compare as 0.
int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
// Copy the operands into locals instead of dereferencing the raw pointers
// (presumably because the element pointers may be unaligned).
bfloat16 x;
memcpy(&x, a, sizeof(bfloat16));
bfloat16 y;
memcpy(&y, b, sizeof(bfloat16));
if (x < y) {
return -1;
}
if (y < x) {
return 1;
}
// Reached when neither x < y nor y < x holds: the values are equal, or
// (assuming IEEE-style comparisons) at least one of them is NaN.
// NaNs sort to the end.
if (!std::isnan(x) && std::isnan(y)) {
return -1;
}
if (std::isnan(x) && !std::isnan(y)) {
return 1;
}
// Equal values, or both operands are NaN.
return 0;
}
void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
npy_intp sstride, npy_intp n, int swap, void* arr) {
char* dst = reinterpret_cast<char*>(dstv);
@ -561,6 +584,7 @@ bool Initialize() {
PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;

View File

@ -32,15 +32,19 @@ from tensorflow.python.platform import test
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
class Bfloat16Test(test.TestCase):
def float_values():
"""Returns values that should round trip exactly to float and back.

The list contains zero, +/-1, +/-0.5, `epsilon` (2**-7) itself, the values
one `epsilon` above and below +/-1, a few larger magnitudes, both infinities
and NaN.
"""
# 2**-7; presumably the spacing between 1.0 and the next bfloat16 value.
epsilon = float.fromhex("1.0p-7")
return [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
def float_values(self):
"""Returns values that should round trip exactly to float and back.

The list contains zero, +/-1, +/-0.5, `epsilon` (2**-7) itself, the values
one `epsilon` above and below +/-1, a few larger magnitudes, both infinities
and NaN.
"""
# 2**-7; presumably the spacing between 1.0 and the next bfloat16 value.
epsilon = float.fromhex("1.0p-7")
return [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"), float("-inf"), float("nan")]
class Bfloat16Test(test.TestCase):
def _assertFloatIdentical(self, v, w):
if math.isnan(v):
@ -49,7 +53,7 @@ class Bfloat16Test(test.TestCase):
self.assertEqual(v, w)
def testRoundTripToFloat(self):
for v in self.float_values():
for v in float_values():
self._assertFloatIdentical(v, float(bfloat16(v)))
def testRoundTripToInt(self):
@ -82,7 +86,7 @@ class Bfloat16Test(test.TestCase):
# Tests for Python operations
def testNegate(self):
for v in self.float_values():
for v in float_values():
self._assertFloatIdentical(-v, float(-bfloat16(v)))
def testAdd(self):
@ -132,33 +136,33 @@ class Bfloat16Test(test.TestCase):
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in self.float_values():
for w in self.float_values():
for v in float_values():
for w in float_values():
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
@ -259,6 +263,12 @@ class Bfloat16NumPyTest(test.TestCase):
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
def testSort(self):
# Sorting the bfloat16 copy of float_values() must yield the same sequence as
# sorting the float32 originals.
values_to_sort = np.float32(float_values())
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
# The bfloat16 result is converted back to float32 before comparing.
self.assertAllEqual(sorted_f32, np.float32(sorted_bf16))
if __name__ == "__main__":
test.main()