[XLA] Avoid UB in absl::bitcast from Eigen::half.
PiperOrigin-RevId: 219331320
This commit is contained in:
parent
507c566376
commit
6b8e08a932
@ -34,16 +34,22 @@ namespace xla {
|
||||
namespace literal_comparison {
|
||||
namespace {
|
||||
|
||||
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
|
||||
// able to transparently access the raw 16-bit value contained within.
|
||||
template <typename T>
|
||||
T GetRawValue(T val) {
|
||||
return val;
|
||||
}
|
||||
uint16 GetRawValue(Eigen::half val) { return val.x; }
|
||||
|
||||
// Helper function for comparing a floating point type, FloatT, bitwise equal
|
||||
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
|
||||
// -- on miscompare, a nice error message is given in the AssertionFailure.
|
||||
template <typename FloatT, typename UnsignedT>
|
||||
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
|
||||
absl::Span<const int64> multi_index) {
|
||||
// TODO(b/118627822): These are unsafe bit_casts because Eigen::Half is not
|
||||
// trivially copyable.
|
||||
auto ulhs = absl::bit_cast<UnsignedT>(lhs);
|
||||
auto urhs = absl::bit_cast<UnsignedT>(rhs);
|
||||
auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
|
||||
auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
|
||||
auto lhs_double = static_cast<double>(lhs);
|
||||
auto rhs_double = static_cast<double>(rhs);
|
||||
if (ulhs != urhs) {
|
||||
|
Loading…
Reference in New Issue
Block a user