[XLA] Add xla::IsNan, IsPosInf, IsNegInf, IsPosOrNegInf.
Useful numeric helper functions. PiperOrigin-RevId: 231503933
This commit is contained in:
parent
91ebeecc92
commit
cfb819c9cc
@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
|
||||
XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
|
||||
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
|
||||
XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
|
||||
XLAJIT_MAKE_UNARY(
|
||||
IsInf,
|
||||
xla::Eq(xla::Abs(x),
|
||||
xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
|
||||
XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
|
||||
XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x));
|
||||
XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x));
|
||||
// Return 1/x
|
||||
XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
|
||||
XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
|
||||
|
@ -184,6 +184,7 @@ cc_library(
|
||||
srcs = ["math.cc"],
|
||||
hdrs = ["math.h"],
|
||||
deps = [
|
||||
":arithmetic",
|
||||
":constants",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -197,6 +198,7 @@ xla_test(
|
||||
deps = [
|
||||
":math",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -26,6 +27,58 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// TODO(jlebar): Use this function in more places in this file to restrict the
|
||||
// domain of other functions.
|
||||
static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
|
||||
auto& b = *operand.builder();
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
|
||||
auto elem_ty = shape.element_type();
|
||||
if (!primitive_util::IsFloatingPointType(elem_ty)) {
|
||||
return InvalidArgument(
|
||||
"Operands to %s must be real-valued floating-point, but got %s",
|
||||
op_name, PrimitiveType_Name(elem_ty));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaOp IsPosInf(XlaOp operand) {
|
||||
auto& b = *operand.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
|
||||
// Note that this is only correct for floating-point types. If we wanted it
|
||||
// to be correct for all types, we'd need to Gt(MaxFiniteValue).
|
||||
return Eq(operand, MaxValue(&b, shape.element_type()));
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp IsNegInf(XlaOp operand) {
|
||||
auto& b = *operand.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
|
||||
// Note that this is only correct for floating-point types. If we wanted it
|
||||
// to be correct for all types, we'd need to Lt(MinFiniteValue).
|
||||
return Eq(operand, MinValue(&b, shape.element_type()));
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp IsInf(XlaOp operand) {
|
||||
auto& b = *operand.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
|
||||
return IsPosInf(Abs(operand));
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp IsNan(XlaOp operand) {
|
||||
auto& b = *operand.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
|
||||
return Ne(operand, operand);
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
|
||||
|
||||
XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
|
||||
@ -101,14 +154,8 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
|
||||
XlaOp Erfc(XlaOp x) {
|
||||
auto& b = *x.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
// Reject non-real non-fp inputs. (We could extend erfc to accept complex
|
||||
// types, but it doesn't seem necessary at this point.)
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
|
||||
if (!ShapeUtil::ElementIsFloating(shape)) {
|
||||
return InvalidArgument(
|
||||
"erfc only accepts real floating-point arrays or scalars, but got %s",
|
||||
shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
|
||||
|
||||
XlaOp abs_x = Abs(x);
|
||||
XlaOp z = Exp(-x * x);
|
||||
|
||||
@ -223,15 +270,7 @@ static constexpr std::array<double, 8> kLanczosCoefficients = {
|
||||
XlaOp Lgamma(XlaOp input) {
|
||||
auto& b = *input.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
// Reject non-real non-fp inputs. (We could extend lgamma to accept complex
|
||||
// types, but it doesn't seem necessary at this point.)
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
|
||||
if (!ShapeUtil::ElementIsFloating(shape)) {
|
||||
return InvalidArgument(
|
||||
"lgamma only accepts real floating-point arrays or scalars, but got "
|
||||
"%s",
|
||||
shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
|
||||
|
||||
XlaOp one_half = ScalarLike(input, 0.5);
|
||||
XlaOp one = ScalarLike(input, 1);
|
||||
@ -321,15 +360,7 @@ XlaOp Lgamma(XlaOp input) {
|
||||
XlaOp Digamma(XlaOp input) {
|
||||
auto& b = *input.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
// Reject non-real non-fp inputs. (We could extend digamma to accept
|
||||
// complex types, but it doesn't seem necessary at this point.)
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
|
||||
if (!ShapeUtil::ElementIsFloating(shape)) {
|
||||
return InvalidArgument(
|
||||
"digamma only accepts real floating-point arrays or scalars, but got "
|
||||
"%s",
|
||||
shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
|
||||
|
||||
XlaOp zero = ScalarLike(input, 0);
|
||||
XlaOp one_half = ScalarLike(input, 0.5);
|
||||
@ -381,12 +412,8 @@ XlaOp RoundToEven(XlaOp x) {
|
||||
// Reject non-real non-fp inputs (What does it even mean to round a complex
|
||||
// number? Do you round each component equally? In that case, you should
|
||||
// just ask for that explicitly.)
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
|
||||
if (ShapeUtil::ElementIsComplex(shape)) {
|
||||
return InvalidArgument(
|
||||
"RoundToEven doesn't accept complex inputs, but got %s",
|
||||
shape.ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
|
||||
|
||||
auto half = ScalarLike(x, 0.5);
|
||||
auto one = ScalarLike(x, 1.0);
|
||||
auto two = ScalarLike(x, 2.0);
|
||||
|
@ -20,6 +20,18 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Determines whether operand is +/-inf or nan.
|
||||
//
|
||||
// Raises an error if called on integral or complex values.
|
||||
XlaOp IsPosInf(XlaOp operand);
|
||||
XlaOp IsNegInf(XlaOp operand);
|
||||
XlaOp IsInf(XlaOp operand);
|
||||
XlaOp IsNan(XlaOp operand);
|
||||
|
||||
// Returns the next number after 'from' in the direction of 'to' the same way
|
||||
// std::nextafter(from, to) would.
|
||||
XlaOp NextAfter(XlaOp from, XlaOp to);
|
||||
|
||||
// Computes the square root of 'operand'.
|
||||
XlaOp Sqrt(XlaOp operand);
|
||||
|
||||
@ -90,10 +102,6 @@ XlaOp Sinh(XlaOp x);
|
||||
// is true, otherwise returns its argument.
|
||||
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
|
||||
|
||||
// Returns the next number after 'from' in the direction of 'to' the same way
|
||||
// std::nextafter(from, to) would.
|
||||
XlaOp NextAfter(XlaOp from, XlaOp to);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
@ -55,6 +56,38 @@ class MathTypedTest : public MathTest {
|
||||
&b, {T{0.0}, T{-0.0}, -std::numeric_limits<T>::infinity()}, {},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
void TestIsInfOrNan() {
|
||||
SetFastMathDisabled(true);
|
||||
|
||||
XlaBuilder b(TestName());
|
||||
auto x =
|
||||
ConstantR1<T>(&b, {
|
||||
T{0},
|
||||
T{100},
|
||||
T{-1000},
|
||||
T{std::numeric_limits<T>::max()},
|
||||
T{std::numeric_limits<T>::lowest()},
|
||||
T{std::numeric_limits<float>::infinity()},
|
||||
T{-std::numeric_limits<float>::infinity()},
|
||||
T{std::numeric_limits<float>::quiet_NaN()},
|
||||
T{std::numeric_limits<float>::signaling_NaN()},
|
||||
});
|
||||
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
|
||||
|
||||
auto expected = LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR1<bool>(
|
||||
{true, true, true, true, true, false, false, false, false}),
|
||||
LiteralUtil::CreateR1<bool>(
|
||||
{false, false, false, false, false, true, true, false, false}),
|
||||
LiteralUtil::CreateR1<bool>(
|
||||
{false, false, false, false, false, true, false, false, false}),
|
||||
LiteralUtil::CreateR1<bool>(
|
||||
{false, false, false, false, false, false, true, false, false}),
|
||||
LiteralUtil::CreateR1<bool>(
|
||||
{false, false, false, false, false, false, false, true, true}));
|
||||
ComputeAndCompareLiteral(&b, expected, {});
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
|
||||
@ -68,6 +101,50 @@ TYPED_TEST_CASE(MathTypedTest, TestTypes);
|
||||
|
||||
XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); }
|
||||
XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); }
|
||||
XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); }
|
||||
|
||||
// Check that certain ops only support real, floating-point inputs.
|
||||
//
|
||||
// TODO(jlebar): Expand this test to cover more ops.
|
||||
XLA_TEST_F(MathTest, RealFpOnlyOps) {
|
||||
for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) {
|
||||
auto ty = static_cast<PrimitiveType>(i);
|
||||
SCOPED_TRACE(PrimitiveType_Name(ty));
|
||||
Shape shape;
|
||||
if (primitive_util::IsArrayType(ty)) {
|
||||
shape = ShapeUtil::MakeShape(ty, {42});
|
||||
} else if (ty == PrimitiveType::TUPLE) {
|
||||
shape = ShapeUtil::MakeTupleShape({});
|
||||
} else if (ty == PrimitiveType::OPAQUE) {
|
||||
shape = ShapeUtil::MakeOpaqueShape();
|
||||
} else if (ty == PrimitiveType::TOKEN) {
|
||||
shape = ShapeUtil::MakeTokenShape();
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto& test :
|
||||
std::vector<std::pair<std::function<XlaOp(XlaOp)>, string>>({
|
||||
{IsFinite, "is_finite"},
|
||||
{IsInf, "is_inf"},
|
||||
{IsPosInf, "is_pos_inf"},
|
||||
{IsNegInf, "is_neg_inf"},
|
||||
{IsNan, "is_nan"},
|
||||
{Erf, "erf"},
|
||||
{Erfc, "erfc"},
|
||||
{Lgamma, "lgamma"},
|
||||
{Digamma, "digamma"},
|
||||
{RoundToEven, "round_to_even"},
|
||||
})) {
|
||||
SCOPED_TRACE(test.second);
|
||||
XlaBuilder b(TestName());
|
||||
XlaOp p = Parameter(&b, 0, shape, "p0");
|
||||
test.first(p);
|
||||
|
||||
EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, SqrtF32) {
|
||||
XlaBuilder builder(TestName());
|
||||
@ -145,8 +222,7 @@ XLA_TEST_F(MathTest, Lgamma) {
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO(jlebar): Fails on interpreter due to unimplemented operation.
|
||||
XLA_TEST_F(MathTest, DISABLED_ON_INTERPRETER(LgammaF16)) {
|
||||
XLA_TEST_F(MathTest, LgammaF16) {
|
||||
SetFastMathDisabled(true);
|
||||
|
||||
XlaBuilder b(TestName());
|
||||
|
@ -1776,10 +1776,14 @@ XlaOp Imag(const XlaOp& operand);
|
||||
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues an operator that tests if the operand's values are finite, i.e.,
|
||||
// not Inf or NaN. Defined only for floating-point types. Returns an array of
|
||||
// booleans with the same shape where entries are true iff the corresponding
|
||||
// entry was NaN.
|
||||
// Enqueues an operator that tests if the operand's values are finite, i.e., not
|
||||
// +/-Inf or NaN. Returns an array of booleans with the same shape where
|
||||
// entries are true iff the corresponding entry was not infinite or NaN.
|
||||
//
|
||||
// Defined only for real-valued (i.e. not complex) floating-point types; raises
|
||||
// an error for other types.
|
||||
//
|
||||
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
|
||||
XlaOp IsFinite(const XlaOp& operand);
|
||||
|
||||
// Enqueues an iota operation onto the computation.
|
||||
|
Loading…
x
Reference in New Issue
Block a user