[XLA] Add xla::IsNan, IsPosInf, IsNegInf, IsPosOrNegInf.

Useful numeric helper functions.

PiperOrigin-RevId: 231503933
This commit is contained in:
Justin Lebar 2019-01-29 17:13:33 -08:00 committed by TensorFlower Gardener
parent 91ebeecc92
commit cfb819c9cc
6 changed files with 161 additions and 47 deletions

View File

@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
XLAJIT_MAKE_UNARY(
IsInf,
xla::Eq(xla::Abs(x),
xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x));
XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x));
// Return 1/x
XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);

View File

@ -184,6 +184,7 @@ cc_library(
srcs = ["math.cc"],
hdrs = ["math.h"],
deps = [
":arithmetic",
":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -197,6 +198,7 @@ xla_test(
deps = [
":math",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -26,6 +27,58 @@ limitations under the License.
namespace xla {
// TODO(jlebar): Use this function in more places in this file to restrict the
// domain of other functions.
static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
auto& b = *operand.builder();
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
auto elem_ty = shape.element_type();
if (!primitive_util::IsFloatingPointType(elem_ty)) {
return InvalidArgument(
"Operands to %s must be real-valued floating-point, but got %s",
op_name, PrimitiveType_Name(elem_ty));
}
return Status::OK();
}
XlaOp IsPosInf(XlaOp operand) {
auto& b = *operand.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
// Note that this is only correct for floating-point types. If we wanted it
// to be correct for all types, we'd need to Gt(MaxFiniteValue).
return Eq(operand, MaxValue(&b, shape.element_type()));
});
}
XlaOp IsNegInf(XlaOp operand) {
auto& b = *operand.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
// Note that this is only correct for floating-point types. If we wanted it
// to be correct for all types, we'd need to Lt(MinFiniteValue).
return Eq(operand, MinValue(&b, shape.element_type()));
});
}
XlaOp IsInf(XlaOp operand) {
auto& b = *operand.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
return IsPosInf(Abs(operand));
});
}
XlaOp IsNan(XlaOp operand) {
auto& b = *operand.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
return Ne(operand, operand);
});
}
XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
@ -101,14 +154,8 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
XlaOp Erfc(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Reject non-real non-fp inputs. (We could extend erfc to accept complex
// types, but it doesn't seem necessary at this point.)
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"erfc only accepts real floating-point arrays or scalars, but got %s",
shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
XlaOp abs_x = Abs(x);
XlaOp z = Exp(-x * x);
@ -223,15 +270,7 @@ static constexpr std::array<double, 8> kLanczosCoefficients = {
XlaOp Lgamma(XlaOp input) {
auto& b = *input.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Reject non-real non-fp inputs. (We could extend lgamma to accept complex
// types, but it doesn't seem necessary at this point.)
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"lgamma only accepts real floating-point arrays or scalars, but got "
"%s",
shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
XlaOp one_half = ScalarLike(input, 0.5);
XlaOp one = ScalarLike(input, 1);
@ -321,15 +360,7 @@ XlaOp Lgamma(XlaOp input) {
XlaOp Digamma(XlaOp input) {
auto& b = *input.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Reject non-real non-fp inputs. (We could extend digamma to accept
// complex types, but it doesn't seem necessary at this point.)
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"digamma only accepts real floating-point arrays or scalars, but got "
"%s",
shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
XlaOp zero = ScalarLike(input, 0);
XlaOp one_half = ScalarLike(input, 0.5);
@ -381,12 +412,8 @@ XlaOp RoundToEven(XlaOp x) {
// Reject non-real non-fp inputs (What does it even mean to round a complex
// number? Do you round each component equally? In that case, you should
// just ask for that explicitly.)
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
if (ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"RoundToEven doesn't accept complex inputs, but got %s",
shape.ToString());
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
auto half = ScalarLike(x, 0.5);
auto one = ScalarLike(x, 1.0);
auto two = ScalarLike(x, 2.0);

View File

@ -20,6 +20,18 @@ limitations under the License.
namespace xla {
// Determines whether operand is +/-inf or nan.
//
// Raises an error if called on integral or complex values.
XlaOp IsPosInf(XlaOp operand);
XlaOp IsNegInf(XlaOp operand);
XlaOp IsInf(XlaOp operand);
XlaOp IsNan(XlaOp operand);
// Returns the next number after 'from' in the direction of 'to' the same way
// std::nextafter(from, to) would.
XlaOp NextAfter(XlaOp from, XlaOp to);
// Computes the square root of 'operand'.
XlaOp Sqrt(XlaOp operand);
@ -90,10 +102,6 @@ XlaOp Sinh(XlaOp x);
// is true, otherwise returns its argument.
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
// Returns the next number after 'from' in the direction of 'to' the same way
// std::nextafter(from, to) would.
XlaOp NextAfter(XlaOp from, XlaOp to);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@ -55,6 +56,38 @@ class MathTypedTest : public MathTest {
&b, {T{0.0}, T{-0.0}, -std::numeric_limits<T>::infinity()}, {},
error_spec_);
}
void TestIsInfOrNan() {
SetFastMathDisabled(true);
XlaBuilder b(TestName());
auto x =
ConstantR1<T>(&b, {
T{0},
T{100},
T{-1000},
T{std::numeric_limits<T>::max()},
T{std::numeric_limits<T>::lowest()},
T{std::numeric_limits<float>::infinity()},
T{-std::numeric_limits<float>::infinity()},
T{std::numeric_limits<float>::quiet_NaN()},
T{std::numeric_limits<float>::signaling_NaN()},
});
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
auto expected = LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<bool>(
{true, true, true, true, true, false, false, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, true, true, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, true, false, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, false, true, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, false, false, true, true}));
ComputeAndCompareLiteral(&b, expected, {});
}
};
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
@ -68,6 +101,50 @@ TYPED_TEST_CASE(MathTypedTest, TestTypes);
XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); }
XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); }
XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); }
// Check that certain ops only support real, floating-point inputs.
//
// TODO(jlebar): Expand this test to cover more ops.
XLA_TEST_F(MathTest, RealFpOnlyOps) {
for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) {
auto ty = static_cast<PrimitiveType>(i);
SCOPED_TRACE(PrimitiveType_Name(ty));
Shape shape;
if (primitive_util::IsArrayType(ty)) {
shape = ShapeUtil::MakeShape(ty, {42});
} else if (ty == PrimitiveType::TUPLE) {
shape = ShapeUtil::MakeTupleShape({});
} else if (ty == PrimitiveType::OPAQUE) {
shape = ShapeUtil::MakeOpaqueShape();
} else if (ty == PrimitiveType::TOKEN) {
shape = ShapeUtil::MakeTokenShape();
} else {
continue;
}
for (const auto& test :
std::vector<std::pair<std::function<XlaOp(XlaOp)>, string>>({
{IsFinite, "is_finite"},
{IsInf, "is_inf"},
{IsPosInf, "is_pos_inf"},
{IsNegInf, "is_neg_inf"},
{IsNan, "is_nan"},
{Erf, "erf"},
{Erfc, "erfc"},
{Lgamma, "lgamma"},
{Digamma, "digamma"},
{RoundToEven, "round_to_even"},
})) {
SCOPED_TRACE(test.second);
XlaBuilder b(TestName());
XlaOp p = Parameter(&b, 0, shape, "p0");
test.first(p);
EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty));
}
}
}
XLA_TEST_F(MathTest, SqrtF32) {
XlaBuilder builder(TestName());
@ -145,8 +222,7 @@ XLA_TEST_F(MathTest, Lgamma) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
// TODO(jlebar): Fails on interpreter due to unimplemented operation.
XLA_TEST_F(MathTest, DISABLED_ON_INTERPRETER(LgammaF16)) {
XLA_TEST_F(MathTest, LgammaF16) {
SetFastMathDisabled(true);
XlaBuilder b(TestName());

View File

@ -1776,10 +1776,14 @@ XlaOp Imag(const XlaOp& operand);
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
// booleans with the same shape where entries are true iff the corresponding
// entry was NaN.
// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN. Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.