[XLA:Python] Don't depend on npy_... math functions in bfloat16 code.
PiperOrigin-RevId: 277522429 Change-Id: Ib1e0b4cf0acdf992891f926fb099435f6563420d
This commit is contained in:
parent
df1485e231
commit
7acd3bb9d7
@ -824,18 +824,44 @@ struct Multiply {
|
|||||||
struct TrueDivide {
|
struct TrueDivide {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
|
bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::pair<float, float> divmod(float a, float b) {
|
||||||
|
if (b == 0.0f) {
|
||||||
|
float nan = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
return {nan, nan};
|
||||||
|
}
|
||||||
|
float mod = std::fmod(a, b);
|
||||||
|
float div = (a - mod) / b;
|
||||||
|
if (mod != 0.0f) {
|
||||||
|
if ((b < 0.0f) != (mod < 0.0f)) {
|
||||||
|
mod += b;
|
||||||
|
div -= 1.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mod = std::copysign(0.0f, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
float floordiv;
|
||||||
|
if (div != 0.0f) {
|
||||||
|
floordiv = std::floor(div);
|
||||||
|
if (div - floordiv > 0.5f) {
|
||||||
|
floordiv += 1.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
floordiv = std::copysign(0.0f, a / b);
|
||||||
|
}
|
||||||
|
return {floordiv, mod};
|
||||||
|
}
|
||||||
|
|
||||||
struct FloorDivide {
|
struct FloorDivide {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float mod;
|
return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
|
||||||
return bfloat16(
|
|
||||||
npy_divmodf(static_cast<float>(a), static_cast<float>(b), &mod));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Remainder {
|
struct Remainder {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
||||||
float mod;
|
return bfloat16(
|
||||||
npy_divmodf(static_cast<float>(a), static_cast<float>(b), &mod);
|
divmod(static_cast<float>(a), static_cast<float>(b)).second);
|
||||||
return bfloat16(mod);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct DivmodUFunc {
|
struct DivmodUFunc {
|
||||||
@ -851,9 +877,10 @@ struct DivmodUFunc {
|
|||||||
for (npy_intp k = 0; k < *dimensions; k++) {
|
for (npy_intp k = 0; k < *dimensions; k++) {
|
||||||
bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
|
bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
|
||||||
bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
|
bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
|
||||||
float mod;
|
float floordiv, mod;
|
||||||
*reinterpret_cast<bfloat16*>(o0) = bfloat16(
|
std::tie(floordiv, mod) =
|
||||||
npy_divmodf(static_cast<float>(x), static_cast<float>(y), &mod));
|
divmod(static_cast<float>(x), static_cast<float>(y));
|
||||||
|
*reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
|
||||||
*reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
|
*reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
|
||||||
i0 += steps[0];
|
i0 += steps[0];
|
||||||
i1 += steps[1];
|
i1 += steps[1];
|
||||||
@ -927,9 +954,18 @@ struct Frexp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Heaviside {
|
struct Heaviside {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
|
||||||
return bfloat16(
|
float x = static_cast<float>(bx);
|
||||||
npy_heavisidef(static_cast<float>(a), static_cast<float>(b)));
|
if (std::isnan(x)) {
|
||||||
|
return bx;
|
||||||
|
}
|
||||||
|
if (x < 0) {
|
||||||
|
return bfloat16(0.0f);
|
||||||
|
}
|
||||||
|
if (x > 0) {
|
||||||
|
return bfloat16(1.0f);
|
||||||
|
}
|
||||||
|
return h0; // x == 0
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Conjugate {
|
struct Conjugate {
|
||||||
@ -970,15 +1006,37 @@ struct Log1p {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct LogAddExp {
|
struct LogAddExp {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 bx, bfloat16 by) {
|
||||||
return bfloat16(
|
float x = static_cast<float>(bx);
|
||||||
npy_logaddexpf(static_cast<float>(a), static_cast<float>(b)));
|
float y = static_cast<float>(by);
|
||||||
|
if (x == y) {
|
||||||
|
// Handles infinities of the same sign.
|
||||||
|
return bfloat16(x + std::log(2.0f));
|
||||||
|
}
|
||||||
|
float out = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
if (x > y) {
|
||||||
|
out = x + std::log1p(std::exp(y - x));
|
||||||
|
} else if (x < y) {
|
||||||
|
out = y + std::log1p(std::exp(x - y));
|
||||||
|
}
|
||||||
|
return bfloat16(out);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct LogAddExp2 {
|
struct LogAddExp2 {
|
||||||
bfloat16 operator()(bfloat16 a, bfloat16 b) {
|
bfloat16 operator()(bfloat16 bx, bfloat16 by) {
|
||||||
return bfloat16(
|
float x = static_cast<float>(bx);
|
||||||
npy_logaddexp2f(static_cast<float>(a), static_cast<float>(b)));
|
float y = static_cast<float>(by);
|
||||||
|
if (x == y) {
|
||||||
|
// Handles infinities of the same sign.
|
||||||
|
return bfloat16(x + 1.0f);
|
||||||
|
}
|
||||||
|
float out = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
if (x > y) {
|
||||||
|
out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
|
||||||
|
} else if (x < y) {
|
||||||
|
out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
|
||||||
|
}
|
||||||
|
return bfloat16(out);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Modf {
|
struct Modf {
|
||||||
@ -1104,12 +1162,14 @@ struct Arctanh {
|
|||||||
};
|
};
|
||||||
struct Deg2rad {
|
struct Deg2rad {
|
||||||
bfloat16 operator()(bfloat16 a) {
|
bfloat16 operator()(bfloat16 a) {
|
||||||
return bfloat16(npy_deg2radf(static_cast<float>(a)));
|
static constexpr float radians_per_degree = M_PI / 180.0f;
|
||||||
|
return bfloat16(static_cast<float>(a) * radians_per_degree);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct Rad2deg {
|
struct Rad2deg {
|
||||||
bfloat16 operator()(bfloat16 a) {
|
bfloat16 operator()(bfloat16 a) {
|
||||||
return bfloat16(npy_rad2degf(static_cast<float>(a)));
|
static constexpr float degrees_per_radian = 180.0f / M_PI;
|
||||||
|
return bfloat16(static_cast<float>(a) * degrees_per_radian);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -317,7 +317,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
} for op in UNARY_UFUNCS))
|
} for op in UNARY_UFUNCS))
|
||||||
def testUnaryUfunc(self, op):
|
def testUnaryUfunc(self, op):
|
||||||
rng = np.random.RandomState(seed=42)
|
rng = np.random.RandomState(seed=42)
|
||||||
x = rng.randn(3, 7).astype(bfloat16)
|
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||||
numpy_assert_allclose(
|
numpy_assert_allclose(
|
||||||
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
|
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
|
||||||
|
|
||||||
@ -327,8 +327,8 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
} for op in BINARY_UFUNCS))
|
} for op in BINARY_UFUNCS))
|
||||||
def testBinaryUfunc(self, op):
|
def testBinaryUfunc(self, op):
|
||||||
rng = np.random.RandomState(seed=42)
|
rng = np.random.RandomState(seed=42)
|
||||||
x = rng.randn(3, 7).astype(bfloat16)
|
x = rng.randn(3, 7, 10).astype(bfloat16)
|
||||||
y = rng.randn(4, 1, 7).astype(bfloat16)
|
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
|
||||||
numpy_assert_allclose(
|
numpy_assert_allclose(
|
||||||
op(x, y).astype(np.float32),
|
op(x, y).astype(np.float32),
|
||||||
op(x.astype(np.float32), y.astype(np.float32)),
|
op(x.astype(np.float32), y.astype(np.float32)),
|
||||||
@ -351,7 +351,7 @@ class Bfloat16NumPyTest(parameterized.TestCase):
|
|||||||
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
|
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
|
||||||
def testPredicateUfunc(self, op):
|
def testPredicateUfunc(self, op):
|
||||||
rng = np.random.RandomState(seed=42)
|
rng = np.random.RandomState(seed=42)
|
||||||
shape = (3, 7)
|
shape = (3, 7, 10)
|
||||||
posinf_flips = rng.rand(*shape) < 0.1
|
posinf_flips = rng.rand(*shape) < 0.1
|
||||||
neginf_flips = rng.rand(*shape) < 0.1
|
neginf_flips = rng.rand(*shape) < 0.1
|
||||||
nan_flips = rng.rand(*shape) < 0.1
|
nan_flips = rng.rand(*shape) < 0.1
|
||||||
|
Loading…
Reference in New Issue
Block a user