[XLA:Python] Don't depend on npy_... math functions in bfloat16 code.

PiperOrigin-RevId: 277522429
Change-Id: Ib1e0b4cf0acdf992891f926fb099435f6563420d
This commit is contained in:
Peter Hawkins 2019-10-30 09:37:19 -07:00 committed by TensorFlower Gardener
parent df1485e231
commit 7acd3bb9d7
2 changed files with 84 additions and 24 deletions

View File

@ -824,18 +824,44 @@ struct Multiply {
struct TrueDivide {
bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
};
std::pair<float, float> divmod(float a, float b) {
if (b == 0.0f) {
float nan = std::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
float mod = std::fmod(a, b);
float div = (a - mod) / b;
if (mod != 0.0f) {
if ((b < 0.0f) != (mod < 0.0f)) {
mod += b;
div -= 1.0f;
}
} else {
mod = std::copysign(0.0f, b);
}
float floordiv;
if (div != 0.0f) {
floordiv = std::floor(div);
if (div - floordiv > 0.5f) {
floordiv += 1.0f;
}
} else {
floordiv = std::copysign(0.0f, a / b);
}
return {floordiv, mod};
}
struct FloorDivide {
bfloat16 operator()(bfloat16 a, bfloat16 b) {
float mod;
return bfloat16(
npy_divmodf(static_cast<float>(a), static_cast<float>(b), &mod));
return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
}
};
struct Remainder {
bfloat16 operator()(bfloat16 a, bfloat16 b) {
float mod;
npy_divmodf(static_cast<float>(a), static_cast<float>(b), &mod);
return bfloat16(mod);
return bfloat16(
divmod(static_cast<float>(a), static_cast<float>(b)).second);
}
};
struct DivmodUFunc {
@ -851,9 +877,10 @@ struct DivmodUFunc {
for (npy_intp k = 0; k < *dimensions; k++) {
bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
float mod;
*reinterpret_cast<bfloat16*>(o0) = bfloat16(
npy_divmodf(static_cast<float>(x), static_cast<float>(y), &mod));
float floordiv, mod;
std::tie(floordiv, mod) =
divmod(static_cast<float>(x), static_cast<float>(y));
*reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
*reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
i0 += steps[0];
i1 += steps[1];
@ -927,9 +954,18 @@ struct Frexp {
}
};
struct Heaviside {
bfloat16 operator()(bfloat16 a, bfloat16 b) {
return bfloat16(
npy_heavisidef(static_cast<float>(a), static_cast<float>(b)));
bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
float x = static_cast<float>(bx);
if (std::isnan(x)) {
return bx;
}
if (x < 0) {
return bfloat16(0.0f);
}
if (x > 0) {
return bfloat16(1.0f);
}
return h0; // x == 0
}
};
struct Conjugate {
@ -970,15 +1006,37 @@ struct Log1p {
}
};
struct LogAddExp {
bfloat16 operator()(bfloat16 a, bfloat16 b) {
return bfloat16(
npy_logaddexpf(static_cast<float>(a), static_cast<float>(b)));
bfloat16 operator()(bfloat16 bx, bfloat16 by) {
float x = static_cast<float>(bx);
float y = static_cast<float>(by);
if (x == y) {
// Handles infinities of the same sign.
return bfloat16(x + std::log(2.0f));
}
float out = std::numeric_limits<float>::quiet_NaN();
if (x > y) {
out = x + std::log1p(std::exp(y - x));
} else if (x < y) {
out = y + std::log1p(std::exp(x - y));
}
return bfloat16(out);
}
};
struct LogAddExp2 {
bfloat16 operator()(bfloat16 a, bfloat16 b) {
return bfloat16(
npy_logaddexp2f(static_cast<float>(a), static_cast<float>(b)));
bfloat16 operator()(bfloat16 bx, bfloat16 by) {
float x = static_cast<float>(bx);
float y = static_cast<float>(by);
if (x == y) {
// Handles infinities of the same sign.
return bfloat16(x + 1.0f);
}
float out = std::numeric_limits<float>::quiet_NaN();
if (x > y) {
out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
} else if (x < y) {
out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
}
return bfloat16(out);
}
};
struct Modf {
@ -1104,12 +1162,14 @@ struct Arctanh {
};
struct Deg2rad {
bfloat16 operator()(bfloat16 a) {
return bfloat16(npy_deg2radf(static_cast<float>(a)));
static constexpr float radians_per_degree = M_PI / 180.0f;
return bfloat16(static_cast<float>(a) * radians_per_degree);
}
};
struct Rad2deg {
bfloat16 operator()(bfloat16 a) {
return bfloat16(npy_rad2degf(static_cast<float>(a)));
static constexpr float degrees_per_radian = 180.0f / M_PI;
return bfloat16(static_cast<float>(a) * degrees_per_radian);
}
};

View File

@ -317,7 +317,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@ -327,8 +327,8 @@ class Bfloat16NumPyTest(parameterized.TestCase):
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
@ -351,7 +351,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1