[XLA] Make kSign return NaN for NaN inputs

Also correct tf.math.sign for NaN inputs, we returned +1.0 instead of 0.0.

PiperOrigin-RevId: 232523293
This commit is contained in:
David Majnemer 2019-02-05 11:21:38 -08:00 committed by TensorFlower Gardener
parent dfcd531df0
commit 87826a32bb
7 changed files with 49 additions and 13 deletions

View File

@ -391,6 +391,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array(
[[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.sign,
np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype),
expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array(
@ -743,6 +748,10 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array(
[[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
self._assertOpOutputMatchesExpected(
math_ops.sign,
np.array([[np.nan]], dtype=dtype),
expected=np.array([[0.0]], dtype=dtype))
def testLogicalOps(self):
self._assertOpOutputMatchesExpected(

View File

@ -89,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) {
}
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign,
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
// softplus(x) = log(1 + exp(x))

View File

@ -1186,7 +1186,7 @@ if and only if the corresponding input element is finite.
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$
using the comparison operator of the element type of `operand`.

View File

@ -440,14 +440,16 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
{operand_value},
{operand_value->getType()}, b_);
case HloOpcode::kSign: {
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = FCmpOEQ(operand_value, zero);
auto olt = FCmpOLT(operand_value, zero);
return Select(oeq, zero,
Select(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
auto ne0_i1 = FCmpONE(operand_value, zero);
auto ne0_float = UIToFP(ne0_i1, type);
llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::copysign, {ne0_float, operand_value},
{operand_value->getType()}, b_);
auto is_nan = FCmpUNO(operand_value, operand_value);
result = Select(is_nan, operand_value, result);
return result;
}
case HloOpcode::kIsFinite: {
// abs(x) o!= inf, this works because the comparison returns false if

View File

@ -462,9 +462,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleNegate<ReturnT>(negate);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr>
Status HandleSign(HloInstruction* sign) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
@ -474,6 +474,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
template <typename NativeT,
typename std::enable_if<
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, Eigen::half>::value ||
std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleSign(HloInstruction* sign) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
return std::isnan(elem_operand)
? elem_operand
: std::copysign(
elem_operand != ElementwiseT(0),
elem_operand);
}));
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>

View File

@ -269,6 +269,11 @@ class IrBuilderMixin {
return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
}
template <class... Args>
llvm::Value* FCmpUNO(Args&&... args) {
return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
}
template <class... Args>
llvm::Value* FDiv(Args&&... args) {
return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);

View File

@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase {
&builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
Sign(arg);
ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
ComputeAndCompareR1<T>(
&builder,
{-1, 1, static_cast<T>(+0.0), static_cast<T>(-0.0), -1, 1, -1}, {});
}
template <typename T>