diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index dc755b1cb26..94f9a9358dc 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -90,6 +90,8 @@ Value MaterializePolynomialApproximation( Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, Location loc, Value operand) { + assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() && + "expect f32 element type"); const std::vector<float> kAlpha{ -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, @@ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> { LogicalResult matchAndRewrite( ErfOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - Type ty = getElementTypeOrSelf(op.getType()); - - // For now, we support only f32. - if (!ty.isF32()) return failure(); - + Location loc = op.getLoc(); ErfOp::Adaptor transformed(operands); - rewriter.replaceOp(op, MaterializeErfApproximationF32( - rewriter, op.getLoc(), transformed.operand())); + Value x = transformed.operand(); + Type ty = x.getType().cast<RankedTensorType>().getElementType(); + + // For now, we support only f32 and f16. + if (!ty.isF32() && !ty.isF16()) return failure(); + + // Cast argument to f32 tensor if needed. + assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); + if (ty.isF16()) { + x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); + } + + Value result = MaterializeErfApproximationF32(rewriter, loc, x); + + // Cast back if needed. + if (ty.isF16()) { + result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); + } + + rewriter.replaceOp(op, result); return success(); } }; diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir index fbba7617aae..fcc8bf6f6fc 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir @@ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> { %1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32> return %1 : tensor<f32> } + +// CHECK-LABEL: @erf_f16 +// CHECK-SAME: %[[ARG:.*]]: tensor<f16> +func @erf_f16(%arg : tensor<f16>) -> tensor<f16> { + // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32> + // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16> + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16> + return %1 : tensor<f16> +} diff --git a/tensorflow/core/kernels/cwise_op_erf.cc b/tensorflow/core/kernels/cwise_op_erf.cc index c1cb53a4757..bb3dc35647b 100644 --- a/tensorflow/core/kernels/cwise_op_erf.cc +++ b/tensorflow/core/kernels/cwise_op_erf.cc @@ -21,11 +21,11 @@ REGISTER3(UnaryOp, CPU, "Erf", functor::erf, float, Eigen::half, double); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, double); +REGISTER(UnaryOp, GPU, "Erf", functor::erf, double); #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \ !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED) -REGISTER(UnaryOp, GPU, "Erf", functor::erf, float); +REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, float); #endif #endif diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index e78e00cc305..a6eea696995 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -385,6 +385,7 @@ gen_kernel_library( name = "erf", tile_size = "256", types = [ + "f16", "f32", ], unroll_factors = "4", diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc index b3c577b16bb..b1982f5ce90 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc @@ -18,6 +18,7 @@ limitations under the License. namespace tensorflow { +GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f16, DT_HALF, Eigen::half); GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float); } // namespace tensorflow