From 87826a32bb0167d7e7437d37f9c4fe94a295b08e Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Tue, 5 Feb 2019 11:21:38 -0800 Subject: [PATCH] [XLA] Make kSign return NaN for NaN inputs Also correct tf.math.sign for NaN inputs: we previously returned +1.0 instead of 0.0. PiperOrigin-RevId: 232523293 --- tensorflow/compiler/tests/unary_ops_test.py | 9 ++++++++ .../compiler/tf2xla/kernels/unary_ops.cc | 5 ++-- .../compiler/xla/g3doc/operation_semantics.md | 2 +- .../xla/service/elemental_ir_emitter.cc | 14 ++++++----- .../xla/service/hlo_evaluator_typed_visitor.h | 23 ++++++++++++++++--- .../xla/service/llvm_ir/ir_builder_mixin.h | 5 ++++ .../compiler/xla/tests/unary_op_test.cc | 4 +++- 7 files changed, 49 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 083e2e58ae0..978ed667ee6 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -391,6 +391,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype), + expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.is_finite, np.array( @@ -743,6 +748,10 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array( [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[np.nan]], dtype=dtype), + expected=np.array([[0.0]], dtype=dtype)) def testLogicalOps(self): self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 4544e034914..62b5cd32da5 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -89,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) { } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); -// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, + xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x))); XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 363fd17b69b..db90d184b52 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1186,7 +1186,7 @@ if and only if the corresponding input element is finite. `Sign(operand)` Element-wise sign operation `x -> sgn(x)` where -$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$ +$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ using the comparison operator of the element type of `operand`. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index ccdc12a30af..fef84ac80a0 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -440,14 +440,16 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( {operand_value}, {operand_value->getType()}, b_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0. 
auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = FCmpOEQ(operand_value, zero); - auto olt = FCmpOLT(operand_value, zero); - return Select(oeq, zero, - Select(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + auto ne0_i1 = FCmpONE(operand_value, zero); + auto ne0_float = UIToFP(ne0_i1, type); + llvm::Value* result = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {ne0_float, operand_value}, + {operand_value->getType()}, b_); + auto is_nan = FCmpUNO(operand_value, operand_value); + result = Select(is_nan, operand_value, result); + return result; } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 742a389ed04..652042e85f0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -462,9 +462,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleNegate(negate); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { @@ -474,6 +474,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value || + std::is_same::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return std::isnan(elem_operand) + ? 
elem_operand + : std::copysign( + elem_operand != ElementwiseT(0), + elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index cf5083e8c13..02c719502ee 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -269,6 +269,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpUNE(std::forward(args)...); } + template + llvm::Value* FCmpUNO(Args&&... args) { + return mixin_builder()->CreateFCmpUNO(std::forward(args)...); + } + template llvm::Value* FDiv(Args&&... args) { return mixin_builder()->CreateFDiv(std::forward(args)...); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 4fbd7f2fb17..c51f30f3b5d 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase { &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); Sign(arg); - ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); + ComputeAndCompareR1( + &builder, + {-1, 1, static_cast(+0.0), static_cast(-0.0), -1, 1, -1}, {}); } template