[XLA] Add xla::IsNan, IsPosInf, IsNegInf, and IsInf.

Useful numeric helper functions.

PiperOrigin-RevId: 231503933
Commit: cfb819c9cc (parent: 91ebeecc92)
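
For context, a minimal sketch of how a client might use the new helpers
(hypothetical builder name and shape; a real program also needs the usual
client compile/execute boilerplate):

    #include "tensorflow/compiler/xla/client/lib/math.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::XlaBuilder b("classify_specials");
    // A rank-1 f32 parameter; each helper yields a same-shaped boolean mask.
    xla::XlaOp x =
        xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "x");
    xla::XlaOp nan_mask = xla::IsNan(x);      // true where x is NaN
    xla::XlaOp pinf_mask = xla::IsPosInf(x);  // true where x == +inf
    xla::XlaOp ninf_mask = xla::IsNegInf(x);  // true where x == -inf
    xla::XlaOp inf_mask = xla::IsInf(x);      // true where |x| == +inf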
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
 XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
 XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
 XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
-XLAJIT_MAKE_UNARY(
-    IsInf,
-    xla::Eq(xla::Abs(x),
-            xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
-XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
+XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x));
+XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x));
 // Return 1/x
 XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
 XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
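
Aside: XLAJIT_MAKE_UNARY registers a tf2xla kernel whose Compile() emits the
given XLA expression, with x bound to the kernel's input. A paraphrased sketch
of the idea, not the exact macro definition:

    // Roughly what XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x)) produces.
    class IsNanOp : public XlaOpKernel {
     public:
      explicit IsNanOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
      void Compile(XlaOpKernelContext* ctx) override {
        xla::XlaOp x = ctx->Input(0);  // the op's single input
        ctx->SetOutput(0, xla::IsNan(x));
      }
    };
    REGISTER_XLA_OP(Name("IsNan"), IsNanOp);

The change swaps the open-coded Eq(Abs(x), inf) and Ne(x, x) expressions for
the shared helpers added to xla/client/lib/math below.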

--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -184,6 +184,7 @@ cc_library(
     srcs = ["math.cc"],
     hdrs = ["math.h"],
     deps = [
+        ":arithmetic",
         ":constants",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -197,6 +198,7 @@ xla_test(
     deps = [
         ":math",
         "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
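
The new ":arithmetic" dep mirrors the arithmetic.h include added to math.cc
below; the test's new ":shape_util" dep backs the ShapeUtil calls in the
RealFpOnlyOps test added to math_test.cc.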

--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/lib/math.h"
 
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -26,6 +27,58 @@ limitations under the License.
 
 namespace xla {
 
+// TODO(jlebar): Use this function in more places in this file to restrict the
+// domain of other functions.
+static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
+  auto& b = *operand.builder();
+  TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
+  auto elem_ty = shape.element_type();
+  if (!primitive_util::IsFloatingPointType(elem_ty)) {
+    return InvalidArgument(
+        "Operands to %s must be real-valued floating-point, but got %s",
+        op_name, PrimitiveType_Name(elem_ty));
+  }
+  return Status::OK();
+}
+
+XlaOp IsPosInf(XlaOp operand) {
+  auto& b = *operand.builder();
+  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
+    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
+    // Note that this is only correct for floating-point types. If we wanted it
+    // to be correct for all types, we'd need to Gt(MaxFiniteValue).
+    return Eq(operand, MaxValue(&b, shape.element_type()));
+  });
+}
+
+XlaOp IsNegInf(XlaOp operand) {
+  auto& b = *operand.builder();
+  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
+    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
+    // Note that this is only correct for floating-point types. If we wanted it
+    // to be correct for all types, we'd need to Lt(MinFiniteValue).
+    return Eq(operand, MinValue(&b, shape.element_type()));
+  });
+}
+
+XlaOp IsInf(XlaOp operand) {
+  auto& b = *operand.builder();
+  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
+    return IsPosInf(Abs(operand));
+  });
+}
+
+XlaOp IsNan(XlaOp operand) {
+  auto& b = *operand.builder();
+  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
+    return Ne(operand, operand);
+  });
+}
+
 XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
 
 XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
@@ -101,14 +154,8 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
 XlaOp Erfc(XlaOp x) {
   auto& b = *x.builder();
   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    // Reject non-real non-fp inputs. (We could extend erfc to accept complex
-    // types, but it doesn't seem necessary at this point.)
-    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
-    if (!ShapeUtil::ElementIsFloating(shape)) {
-      return InvalidArgument(
-          "erfc only accepts real floating-point arrays or scalars, but got %s",
-          shape.ToString());
-    }
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
+
     XlaOp abs_x = Abs(x);
     XlaOp z = Exp(-x * x);
 
@@ -223,15 +270,7 @@ static constexpr std::array<double, 8> kLanczosCoefficients = {
 XlaOp Lgamma(XlaOp input) {
   auto& b = *input.builder();
   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    // Reject non-real non-fp inputs. (We could extend lgamma to accept complex
-    // types, but it doesn't seem necessary at this point.)
-    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
-    if (!ShapeUtil::ElementIsFloating(shape)) {
-      return InvalidArgument(
-          "lgamma only accepts real floating-point arrays or scalars, but got "
-          "%s",
-          shape.ToString());
-    }
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
 
     XlaOp one_half = ScalarLike(input, 0.5);
     XlaOp one = ScalarLike(input, 1);
@@ -321,15 +360,7 @@ XlaOp Lgamma(XlaOp input) {
 XlaOp Digamma(XlaOp input) {
   auto& b = *input.builder();
   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    // Reject non-real non-fp inputs. (We could extend digamma to accept
-    // complex types, but it doesn't seem necessary at this point.)
-    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(input));
-    if (!ShapeUtil::ElementIsFloating(shape)) {
-      return InvalidArgument(
-          "digamma only accepts real floating-point arrays or scalars, but got "
-          "%s",
-          shape.ToString());
-    }
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
 
     XlaOp zero = ScalarLike(input, 0);
     XlaOp one_half = ScalarLike(input, 0.5);
@@ -381,12 +412,8 @@ XlaOp RoundToEven(XlaOp x) {
     // Reject non-real non-fp inputs (What does it even mean to round a complex
     // number? Do you round each component equally? In that case, you should
     // just ask for that explicitly.)
-    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
-    if (ShapeUtil::ElementIsComplex(shape)) {
-      return InvalidArgument(
-          "RoundToEven doesn't accept complex inputs, but got %s",
-          shape.ToString());
-    }
+    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
+
     auto half = ScalarLike(x, 0.5);
     auto one = ScalarLike(x, 1.0);
     auto two = ScalarLike(x, 2.0);
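
IsPosInf/IsNegInf lean on a floating-point-only property: for fp element types
XLA's MaxValue/MinValue constants are +/-infinity, so plain equality suffices
(hence the EnsureOperandIsRealFp guard and the in-code note about
Gt(MaxFiniteValue)). A rough scalar analogue of the four helpers, assuming
IEEE-754 doubles (illustration only; the real ops run element-wise inside an
XLA computation):

    #include <cmath>
    #include <limits>

    bool IsPosInfRef(double x) {
      // Mirrors Eq(operand, MaxValue): for fp types MaxValue is +infinity.
      return x == std::numeric_limits<double>::infinity();
    }
    bool IsNegInfRef(double x) {
      return x == -std::numeric_limits<double>::infinity();
    }
    bool IsInfRef(double x) { return IsPosInfRef(std::fabs(x)); }
    // NaN is the only value that compares unequal to itself.
    bool IsNanRef(double x) { return x != x; }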

--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -20,6 +20,18 @@ limitations under the License.
 
 namespace xla {
 
+// Determines whether operand is +/-inf or nan.
+//
+// Raises an error if called on integral or complex values.
+XlaOp IsPosInf(XlaOp operand);
+XlaOp IsNegInf(XlaOp operand);
+XlaOp IsInf(XlaOp operand);
+XlaOp IsNan(XlaOp operand);
+
+// Returns the next number after 'from' in the direction of 'to' the same way
+// std::nextafter(from, to) would.
+XlaOp NextAfter(XlaOp from, XlaOp to);
+
 // Computes the square root of 'operand'.
 XlaOp Sqrt(XlaOp operand);
 
@@ -90,10 +102,6 @@ XlaOp Sinh(XlaOp x);
 // is true, otherwise returns its argument.
 xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
 
-// Returns the next number after 'from' in the direction of 'to' the same way
-// std::nextafter(from, to) would.
-XlaOp NextAfter(XlaOp from, XlaOp to);
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_

--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -55,6 +56,38 @@ class MathTypedTest : public MathTest {
         &b, {T{0.0}, T{-0.0}, -std::numeric_limits<T>::infinity()}, {},
         error_spec_);
   }
+
+  void TestIsInfOrNan() {
+    SetFastMathDisabled(true);
+
+    XlaBuilder b(TestName());
+    auto x =
+        ConstantR1<T>(&b, {
+                              T{0},
+                              T{100},
+                              T{-1000},
+                              T{std::numeric_limits<T>::max()},
+                              T{std::numeric_limits<T>::lowest()},
+                              T{std::numeric_limits<float>::infinity()},
+                              T{-std::numeric_limits<float>::infinity()},
+                              T{std::numeric_limits<float>::quiet_NaN()},
+                              T{std::numeric_limits<float>::signaling_NaN()},
+                          });
+    Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
+
+    auto expected = LiteralUtil::MakeTupleOwned(
+        LiteralUtil::CreateR1<bool>(
+            {true, true, true, true, true, false, false, false, false}),
+        LiteralUtil::CreateR1<bool>(
+            {false, false, false, false, false, true, true, false, false}),
+        LiteralUtil::CreateR1<bool>(
+            {false, false, false, false, false, true, false, false, false}),
+        LiteralUtil::CreateR1<bool>(
+            {false, false, false, false, false, false, true, false, false}),
+        LiteralUtil::CreateR1<bool>(
+            {false, false, false, false, false, false, false, true, true}));
+    ComputeAndCompareLiteral(&b, expected, {});
+  }
 };
 
 // TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
@@ -68,6 +101,50 @@ TYPED_TEST_CASE(MathTypedTest, TestTypes);
 
 XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); }
 XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); }
+XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); }
+
+// Check that certain ops only support real, floating-point inputs.
+//
+// TODO(jlebar): Expand this test to cover more ops.
+XLA_TEST_F(MathTest, RealFpOnlyOps) {
+  for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) {
+    auto ty = static_cast<PrimitiveType>(i);
+    SCOPED_TRACE(PrimitiveType_Name(ty));
+    Shape shape;
+    if (primitive_util::IsArrayType(ty)) {
+      shape = ShapeUtil::MakeShape(ty, {42});
+    } else if (ty == PrimitiveType::TUPLE) {
+      shape = ShapeUtil::MakeTupleShape({});
+    } else if (ty == PrimitiveType::OPAQUE) {
+      shape = ShapeUtil::MakeOpaqueShape();
+    } else if (ty == PrimitiveType::TOKEN) {
+      shape = ShapeUtil::MakeTokenShape();
+    } else {
+      continue;
+    }
+
+    for (const auto& test :
+         std::vector<std::pair<std::function<XlaOp(XlaOp)>, string>>({
+             {IsFinite, "is_finite"},
+             {IsInf, "is_inf"},
+             {IsPosInf, "is_pos_inf"},
+             {IsNegInf, "is_neg_inf"},
+             {IsNan, "is_nan"},
+             {Erf, "erf"},
+             {Erfc, "erfc"},
+             {Lgamma, "lgamma"},
+             {Digamma, "digamma"},
+             {RoundToEven, "round_to_even"},
+         })) {
+      SCOPED_TRACE(test.second);
+      XlaBuilder b(TestName());
+      XlaOp p = Parameter(&b, 0, shape, "p0");
+      test.first(p);
+
+      EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty));
+    }
+  }
+}
 
 XLA_TEST_F(MathTest, SqrtF32) {
   XlaBuilder builder(TestName());
@@ -145,8 +222,7 @@ XLA_TEST_F(MathTest, Lgamma) {
   ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
 }
 
-// TODO(jlebar): Fails on interpreter due to unimplemented operation.
-XLA_TEST_F(MathTest, DISABLED_ON_INTERPRETER(LgammaF16)) {
+XLA_TEST_F(MathTest, LgammaF16) {
   SetFastMathDisabled(true);
 
   XlaBuilder b(TestName());
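
As a reading aid, the expected tuple in TestIsInfOrNan corresponds to the five
ops and nine inputs as follows (one row per op, one column per input):

    input:     0  100  -1000  max  lowest  +inf  -inf  qNaN  sNaN
    IsFinite:  T  T    T      T    T       F     F     F     F
    IsInf:     F  F    F      F    F       T     T     F     F
    IsPosInf:  F  F    F      F    F       T     F     F     F
    IsNegInf:  F  F    F      F    F       F     T     F     F
    IsNan:     F  F    F      F    F       F     F     T     T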

--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -1776,10 +1776,14 @@ XlaOp Imag(const XlaOp& operand);
 XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});
 
-// Enqueues an operator that tests if the operand's values are finite, i.e.,
-// not Inf or NaN. Defined only for floating-point types. Returns an array of
-// booleans with the same shape where entries are true iff the corresponding
-// entry was NaN.
+// Enqueues an operator that tests if the operand's values are finite, i.e., not
+// +/-Inf or NaN. Returns an array of booleans with the same shape where
+// entries are true iff the corresponding entry was not infinite or NaN.
+//
+// Defined only for real-valued (i.e. not complex) floating-point types; raises
+// an error for other types.
+//
+// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
 XlaOp IsFinite(const XlaOp& operand);
 
 // Enqueues an iota operation onto the computation.