[MLIR][KernelGen] Add erf kernel and missing lowering for f16 type

PiperOrigin-RevId: 352416184
Change-Id: I6c49a6e0da11d74380aacc8612e2588308672bb9
This commit is contained in:
Author: A. Unique TensorFlower · Date: 2021-01-18 08:19:56 -08:00 · Committed by: TensorFlower Gardener
parent 64ef722547
commit 9cb96ea8fd
5 changed files with 37 additions and 9 deletions

View File

@@ -90,6 +90,8 @@ Value MaterializePolynomialApproximation(
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) {
assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
@@ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
// NOTE(review): this hunk is a unified diff rendered with its +/- markers
// stripped, so the pre-change (f32-only) and post-change (f16 + f32) lines
// of ConvertErfOp::matchAndRewrite appear interleaved below. Annotations
// mark the apparent provenance of each run of lines — confirm against the
// upstream commit before relying on them.
LogicalResult matchAndRewrite(
ErfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// -- removed lines (old implementation: bail out unless f32) --
Type ty = getElementTypeOrSelf(op.getType());
// For now, we support only f32.
if (!ty.isF32()) return failure();
// -- context lines shared by both versions --
Location loc = op.getLoc();
ErfOp::Adaptor transformed(operands);
// -- removed lines (old implementation: replace directly with the f32
//    approximation, no type handling) --
rewriter.replaceOp(op, MaterializeErfApproximationF32(
rewriter, op.getLoc(), transformed.operand()));
// -- added lines (new implementation): derive the element type from the
//    operand, accept f16 as well, and bracket the f32 approximation with
//    mhlo.convert casts when the input is f16 --
Value x = transformed.operand();
Type ty = x.getType().cast<RankedTensorType>().getElementType();
// For now, we support only f32 and f16.
if (!ty.isF32() && !ty.isF16()) return failure();
// Cast argument to f32 tensor if needed.
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
if (ty.isF16()) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
}
Value result = MaterializeErfApproximationF32(rewriter, loc, x);
// Cast back if needed.
if (ty.isF16()) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
}
rewriter.replaceOp(op, result);
return success();
}
};

View File

@@ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}
// New FileCheck test for the f16 erf lowering: expects an upcast of the
// argument to f32, the f32 erf approximation in between, and a downcast
// of the result back to f16.
// CHECK-LABEL: @erf_f16
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
// CHECK: return %[[RESULT]]
%1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}

View File

@@ -21,11 +21,11 @@ REGISTER3(UnaryOp, CPU, "Erf", functor::erf, float, Eigen::half, double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// NOTE(review): diff rendering with +/- markers stripped — the next two
// lines appear to be the removed (old) and added (new) registrations:
// before this change Eigen::half and double used the Eigen GPU kernel
// unconditionally; after it, only double does, since half is now served
// by the MLIR-generated kernel. Confirm against the upstream commit.
REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, double);
REGISTER(UnaryOp, GPU, "Erf", functor::erf, double);
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
// Fallback registrations when the MLIR-generated kernels are not compiled
// in: the old line covered float only; the new line also covers half,
// presumably because the f16 kernel is behind the experimental flag.
REGISTER(UnaryOp, GPU, "Erf", functor::erf, float);
REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, float);
#endif
#endif

View File

@@ -385,6 +385,7 @@ gen_kernel_library(
name = "erf",
tile_size = "256",
types = [
"f16",
"f32",
],
unroll_factors = "4",

View File

@@ -18,6 +18,7 @@ limitations under the License.
namespace tensorflow {
// Instantiates and registers the generated unary Erf kernels. The f16
// (DT_HALF / Eigen::half) line is the addition in this change; the f32
// registration was already present.
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float);
} // namespace tensorflow