[XLA] Make kSign return NaN for NaN inputs
Also correct tf.math.sign for NaN inputs, we returned +1.0 instead of 0.0. PiperOrigin-RevId: 232523293
This commit is contained in:
parent
dfcd531df0
commit
87826a32bb
@ -391,6 +391,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
expected=np.array(
|
||||
[[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.sign,
|
||||
np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype),
|
||||
expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.is_finite,
|
||||
np.array(
|
||||
@ -743,6 +748,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array(
|
||||
[[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
|
||||
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.sign,
|
||||
np.array([[np.nan]], dtype=dtype),
|
||||
expected=np.array([[0.0]], dtype=dtype))
|
||||
|
||||
def testLogicalOps(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
@ -89,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) {
|
||||
}
|
||||
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
|
||||
|
||||
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
|
||||
XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
|
||||
// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
|
||||
XLAJIT_MAKE_UNARY(Sign,
|
||||
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
|
||||
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
|
||||
|
||||
// softplus(x) = log(1 + exp(x))
|
||||
|
@ -1186,7 +1186,7 @@ if and only if the corresponding input element is finite.
|
||||
|
||||
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
|
||||
|
||||
$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
|
||||
$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$
|
||||
|
||||
using the comparison operator of the element type of `operand`.
|
||||
|
||||
|
@ -440,14 +440,16 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
||||
{operand_value},
|
||||
{operand_value->getType()}, b_);
|
||||
case HloOpcode::kSign: {
|
||||
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
|
||||
auto type = operand_value->getType();
|
||||
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||
auto oeq = FCmpOEQ(operand_value, zero);
|
||||
auto olt = FCmpOLT(operand_value, zero);
|
||||
return Select(oeq, zero,
|
||||
Select(olt, llvm::ConstantFP::get(type, -1.0),
|
||||
llvm::ConstantFP::get(type, 1.0)));
|
||||
auto ne0_i1 = FCmpONE(operand_value, zero);
|
||||
auto ne0_float = UIToFP(ne0_i1, type);
|
||||
llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::copysign, {ne0_float, operand_value},
|
||||
{operand_value->getType()}, b_);
|
||||
auto is_nan = FCmpUNO(operand_value, operand_value);
|
||||
result = Select(is_nan, operand_value, result);
|
||||
return result;
|
||||
}
|
||||
case HloOpcode::kIsFinite: {
|
||||
// abs(x) o!= inf, this works because the comparison returns false if
|
||||
|
@ -462,9 +462,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleNegate<ReturnT>(negate);
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
|
||||
nullptr>
|
||||
Status HandleSign(HloInstruction* sign) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||
ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
|
||||
@ -474,6 +474,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, Eigen::half>::value ||
|
||||
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
||||
Status HandleSign(HloInstruction* sign) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||
ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
|
||||
return std::isnan(elem_operand)
|
||||
? elem_operand
|
||||
: std::copysign(
|
||||
elem_operand != ElementwiseT(0),
|
||||
elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
|
@ -269,6 +269,11 @@ class IrBuilderMixin {
|
||||
return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
llvm::Value* FCmpUNO(Args&&... args) {
|
||||
return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
llvm::Value* FDiv(Args&&... args) {
|
||||
return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
|
||||
|
@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase {
|
||||
&builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
|
||||
Sign(arg);
|
||||
|
||||
ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
|
||||
ComputeAndCompareR1<T>(
|
||||
&builder,
|
||||
{-1, 1, static_cast<T>(+0.0), static_cast<T>(-0.0), -1, 1, -1}, {});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
Loading…
Reference in New Issue
Block a user