Fix SignOp lowering for floating point values.
The lowering didn't return 0 for the inputs 0.0 and -0.0. We now emit -0.0 for
-0.0, which is correct according to the HLO dialect; for the TF_SignOp we
should emit 0.0 in that case, which we leave as a TODO. Enable the tests that
work now, and add another one for Int64. Also improve the registration code:
we should not register the Int32 kernel.

PiperOrigin-RevId: 347602378
Change-Id: Id6d3f545d197a1fc9cd7d3d5a2afe7a774e149b5
This commit is contained in:
parent
53bf046ac3
commit
282c83de94
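In scalar terms, the fixed float lowering computes: sign(x) = x when x is NaN,
otherwise copysign(x != 0 ? 1.0 : 0.0, x). A minimal C++ sketch of those
semantics, for illustration only (the helper name sign_of is made up, not code
from this change):

#include <cmath>

float sign_of(float x) {
  if (std::isnan(x)) return x;                  // NaN propagates unchanged.
  float magnitude = (x != 0.0f) ? 1.0f : 0.0f;  // cmpf "one" + uitofp below.
  return std::copysign(magnitude, x);           // copysign below.
}

Note that copysign(0.0, -0.0) is -0.0, which is why the new code still emits
-0.0 for -0.0.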
@@ -539,14 +539,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
   Type element_type = getElementTypeOrSelf(args.front().getType());
   if (auto float_type = element_type.dyn_cast<FloatType>()) {
     bool ignored;
-    APFloat one_apfloat(1.0f);
-    one_apfloat.convert(float_type.getFloatSemantics(),
-                        APFloat::rmNearestTiesToEven, &ignored);
-    Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
+    APFloat zero_apfloat(0.0f);
+    zero_apfloat.convert(float_type.getFloatSemantics(),
+                         APFloat::rmNearestTiesToEven, &ignored);
+    Value zero =
+        b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
     if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
-      one = b->create<::mlir::SplatOp>(loc, vec_type, one);
+      zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
     }
-    return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
+    Value ne0_i1 =
+        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
+    Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, float_type);
+    Value copy_sign =
+        b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
+    auto is_nan =
+        b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
+    return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
   } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
     // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
     Value zero =
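The integer branch (unchanged here) uses the shift trick named in its comment.
As plain C++ for 32-bit values, assuming an arithmetic right shift for signed
integers, which is what the `s>>` in the comment denotes (illustration only,
the helper name sign_int32 is made up):

#include <cstdint>

int32_t sign_int32(int32_t x) {
  if (x == 0) return 0;
  // An arithmetic shift by 31 yields 0 for positive x and -1 (all ones) for
  // negative x; OR-ing with 1 turns those into +1 and -1 respectively.
  return (x >> 31) | 1;
}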
@@ -594,8 +594,12 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
 }
 // CHECK: linalg.generic
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32
-// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : f32
+// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to f32
+// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : f32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32

 // -----
@@ -607,8 +611,12 @@ func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
 }
 // CHECK: linalg.generic
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16
-// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : bf16
+// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : bf16
+// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to bf16
+// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : bf16
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : bf16
 // CHECK-NEXT: linalg.yield %[[RESULT]] : bf16

 // -----
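For readers unfamiliar with the cmpf predicates in the two CHECK hunks above:
"one" is ordered-not-equal (false whenever an operand is NaN) and "uno" is
unordered (true iff an operand is NaN), so comparing a value with itself
detects NaN. In C++ terms, as a sketch (these helpers are not part of the code
under test):

#include <cmath>

bool cmpf_one(float a, float b) {  // ordered and not equal
  return !std::isnan(a) && !std::isnan(b) && a != b;
}
bool cmpf_uno(float a, float b) {  // unordered: at least one operand is NaN
  return std::isnan(a) || std::isnan(b);
}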
@@ -23,6 +23,9 @@ REGISTER8(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64,
     !defined(MLIR_GENERATED_UNRANKED_GPU_KERNELS_ENABLED)
 REGISTER6(UnaryOp, GPU, "Sign", functor::sign, float, Eigen::half, double,
           int64, complex64, complex128);
+#else
+REGISTER2(UnaryOp, GPU, "Sign", functor::sign, complex64, complex128);
+#endif

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -34,6 +37,5 @@ REGISTER_KERNEL_BUILDER(Name("Sign")
                             .TypeConstraint<int32>("T"),
                         UnaryOp<CPUDevice, functor::sign<int32>>);
 #endif
-#endif

 }  // namespace tensorflow
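The second hunk only shows the tail of the special int32 registration. The
standard TensorFlow pattern it belongs to presumably looks like this; the
.Device and .HostMemory lines are inferred from the surrounding comments, not
shown in the diff:

REGISTER_KERNEL_BUILDER(Name("Sign")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        UnaryOp<CPUDevice, functor::sign<int32>>);

It registers the CPU functor for int32 with inputs and outputs pinned to host
memory, which is why it must stay registered even when the MLIR-generated GPU
kernels are enabled.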
@@ -515,27 +515,34 @@ T expected_sign(T x) {
   return 1;
 }

-// TODO(b/162577610): Enable these tests when our generated kernels handle 0.0
-// and -0.0 correctly.
-TEST_F(GpuUnaryOpTest, DISABLED_SignFloat) {
+TEST_F(GpuUnaryOpTest, SignFloat) {
   Run<float>(DefaultInputShape(), DefaultInput<float>(),
              /*op_name=*/"Sign",
              /*expected_callback=*/expected_sign,
              /*expect_equal=*/true);
 }

-TEST_F(GpuUnaryOpTest, DISABLED_SignDouble) {
+TEST_F(GpuUnaryOpTest, SignDouble) {
   Run<double>(DefaultInputShape(), DefaultInput<double>(),
               /*op_name=*/"Sign",
               /*expected_callback=*/expected_sign,
               /*expect_equal=*/true);
 }

-TEST_F(GpuUnaryOpTest, DISABLED_SignHalf) {
+TEST_F(GpuUnaryOpTest, SignHalf) {
   Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
                           /*op_name=*/"Sign",
                           /*expected_callback=*/expected_sign,
+                          // TODO(b/162577610): We should actually use true
+                          // here. This requires returning 0.0 for input -0.0.
                           /*expect_equal=*/false);
 }

+TEST_F(GpuUnaryOpTest, SignInt64) {
+  Run<int64>(DefaultInputShape(), DefaultInput<int64>(),
+             /*op_name=*/"Sign",
+             /*expected_callback=*/expected_sign,
+             /*expect_equal=*/true);
+}
+
 /// Test `tf.Sin`.
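The hunk above only shows the tail of expected_sign. A helper with the
behavior these tests assert would look roughly like this (a sketch; the actual
definition lives earlier in the test file and is not part of this diff):

template <typename T>
T expected_sign(T x) {
  if (x == T(0)) return T(0);
  if (x < T(0)) return T(-1);
  return T(1);
}

Because this baseline returns 0 for -0.0 while the kernel returns -0.0, the
half test keeps expect_equal=false until the TODO above is resolved.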
@@ -21,7 +21,7 @@ namespace tensorflow {
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f16, DT_HALF, Eigen::half);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float);
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double);
-GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i32, DT_INT32, int32);
+// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
 GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64);
 // TODO(b/162577610): Register the kernel for complex types and bfloat.