From b0199879f09e4bd0068e627e1a43384dcff34983 Mon Sep 17 00:00:00 2001
From: Stephan Herhut <herhut@google.com>
Date: Wed, 17 Jun 2020 01:26:51 -0700
Subject: [PATCH] Add lowering for TanhOp that uses an approximation instead of
 lowering to intrinsics.

The same approximation is used by the XLA compiler.

PiperOrigin-RevId: 316844625
Change-Id: I1a909bd063509491a6a58ae0acc3bfb919cb34d5
---
 tensorflow/compiler/mlir/xla/BUILD            |  19 ++
 .../tests/legalize_tanh_to_approximation.mlir | 134 ++++++++++++++
 .../legalize_tanh_to_approximation.cc         | 167 ++++++++++++++++++
 .../compiler/mlir/xla/transforms/passes.h     |   7 +
 .../compiler/mlir/xla/transforms/rewriters.h  |   8 +
 5 files changed, 335 insertions(+)
 create mode 100644 tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir
 create mode 100644 tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc

diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 43458aab2d3..d089f80d571 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -515,6 +515,24 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "xla_legalize_tanh_to_approximation",
+    srcs = ["transforms/legalize_tanh_to_approximation.cc"],
+    hdrs = [
+        "transforms/passes.h",
+        "transforms/rewriters.h",
+    ],
+    deps = [
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 gentbl(
     name = "xla_lower_complex_inc_gen",
     tbl_outs = [
@@ -946,6 +964,7 @@ cc_library(
         ":xla_hlo_fusion",
         ":xla_hlo_to_lhlo_with_xla",
         ":xla_legalize_control_flow",
+        ":xla_legalize_tanh_to_approximation",
         ":xla_legalize_tf",
         ":xla_legalize_tf_with_tf2xla",
         ":xla_legalize_to_linalg",
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir
new file mode 100644
index 00000000000..a8286c9b5a9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir
@@ -0,0 +1,134 @@
+// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
+
+func @tanh_f64(%arg0 : f64) -> f64 {
+  %res = tanh %arg0 : f64
+  return %res : f64
+}
+
+// CHECK-LABEL: @tanh_f64
+// CHECK: tanh
+
+// -----
+
+func @tanh_f32(%arg0 : f32) -> f32 {
+  %res = tanh %arg0 : f32
+  return %res : f32
+}
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK:       module {
+
+// CHECK-LABEL:   func @tanh_f32(
+// CHECK-SAME:                   %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = constant 2.000000e+01 : f32
+// CHECK:           %[[VAL_2:.*]] = constant 1.000000e+00 : f32
+// CHECK:           %[[VAL_3:.*]] = constant 4.000000e-04 : f32
+// CHECK:           %[[VAL_4:.*]] = constant 9.000000e+00 : f32
+// CHECK:           %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
+// CHECK:           %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
+// CHECK:           %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
+// CHECK:           %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
+// CHECK:           %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
+// CHECK:           %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
+// CHECK:           %[[VAL_11:.*]] = constant 0.00489352457 : f32
+// CHECK:           %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
+// CHECK:           %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
+// CHECK:           %[[VAL_14:.*]] = constant 0.00226843474 : f32
+// CHECK:           %[[VAL_15:.*]] = constant 0.00489352504 : f32
+// CHECK:           %[[VAL_16:.*]] = absf %[[VAL_0]] : f32
+// CHECK:           %[[VAL_17:.*]] = copysign %[[VAL_2]], %[[VAL_0]] : f32
+// CHECK:           %[[VAL_18:.*]] = cmpf "ult", %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_19:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_3]] : f32
+// CHECK:           %[[VAL_20:.*]] = cmpf "ule", %[[VAL_16]], %[[VAL_4]] : f32
+// CHECK:           %[[VAL_21:.*]] = copysign %[[VAL_4]], %[[VAL_0]] : f32
+// CHECK:           %[[VAL_22:.*]] = select %[[VAL_20]], %[[VAL_0]], %[[VAL_21]] : f32
+// CHECK:           %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_22]] : f32
+// CHECK:           %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_5]] : f32
+// CHECK:           %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
+// CHECK:           %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32
+// CHECK:           %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
+// CHECK:           %[[VAL_28:.*]] = mulf %[[VAL_23]], %[[VAL_27]] : f32
+// CHECK:           %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
+// CHECK:           %[[VAL_30:.*]] = mulf %[[VAL_23]], %[[VAL_29]] : f32
+// CHECK:           %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
+// CHECK:           %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
+// CHECK:           %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
+// CHECK:           %[[VAL_34:.*]] = mulf %[[VAL_23]], %[[VAL_33]] : f32
+// CHECK:           %[[VAL_35:.*]] = addf %[[VAL_34]], %[[VAL_11]] : f32
+// CHECK:           %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_35]] : f32
+// CHECK:           %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_12]] : f32
+// CHECK:           %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
+// CHECK:           %[[VAL_39:.*]] = mulf %[[VAL_23]], %[[VAL_38]] : f32
+// CHECK:           %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
+// CHECK:           %[[VAL_41:.*]] = mulf %[[VAL_23]], %[[VAL_40]] : f32
+// CHECK:           %[[VAL_42:.*]] = addf %[[VAL_41]], %[[VAL_15]] : f32
+// CHECK:           %[[VAL_43:.*]] = divf %[[VAL_36]], %[[VAL_42]] : f32
+// CHECK:           %[[VAL_44:.*]] = select %[[VAL_19]], %[[VAL_0]], %[[VAL_43]] : f32
+// CHECK:           %[[VAL_45:.*]] = select %[[VAL_18]], %[[VAL_44]], %[[VAL_17]] : f32
+// CHECK:           return %[[VAL_45]] : f32
+// CHECK:         }
+// CHECK:       }
+
+// -----
+
+func @tanh_f16(%arg0 : f16) -> f16 {
+  %res = tanh %arg0 : f16
+  return %res : f16
+}
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK:       module {
+
+// CHECK-LABEL:   func @tanh_f16(
+// CHECK-SAME:                   %[[VAL_0:.*]]: f16) -> f16 {
+// CHECK:           %[[VAL_1:.*]] = constant 2.000000e+01 : f32
+// CHECK:           %[[VAL_2:.*]] = constant 1.000000e+00 : f32
+// CHECK:           %[[VAL_3:.*]] = constant 4.000000e-04 : f32
+// CHECK:           %[[VAL_4:.*]] = constant 9.000000e+00 : f32
+// CHECK:           %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
+// CHECK:           %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
+// CHECK:           %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
+// CHECK:           %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
+// CHECK:           %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
+// CHECK:           %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
+// CHECK:           %[[VAL_11:.*]] = constant 0.00489352457 : f32
+// CHECK:           %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
+// CHECK:           %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
+// CHECK:           %[[VAL_14:.*]] = constant 0.00226843474 : f32
+// CHECK:           %[[VAL_15:.*]] = constant 0.00489352504 : f32
+// CHECK:           %[[VAL_16:.*]] = fpext %[[VAL_0]] : f16 to f32
+// CHECK:           %[[VAL_17:.*]] = absf %[[VAL_16]] : f32
+// CHECK:           %[[VAL_18:.*]] = copysign %[[VAL_2]], %[[VAL_16]] : f32
+// CHECK:           %[[VAL_19:.*]] = cmpf "ult", %[[VAL_17]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_20:.*]] = cmpf "olt", %[[VAL_17]], %[[VAL_3]] : f32
+// CHECK:           %[[VAL_21:.*]] = cmpf "ule", %[[VAL_17]], %[[VAL_4]] : f32
+// CHECK:           %[[VAL_22:.*]] = copysign %[[VAL_4]], %[[VAL_16]] : f32
+// CHECK:           %[[VAL_23:.*]] = select %[[VAL_21]], %[[VAL_16]], %[[VAL_22]] : f32
+// CHECK:           %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_23]] : f32
+// CHECK:           %[[VAL_25:.*]] = mulf %[[VAL_24]], %[[VAL_5]] : f32
+// CHECK:           %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
+// CHECK:           %[[VAL_27:.*]] = mulf %[[VAL_24]], %[[VAL_26]] : f32
+// CHECK:           %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
+// CHECK:           %[[VAL_29:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f32
+// CHECK:           %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
+// CHECK:           %[[VAL_31:.*]] = mulf %[[VAL_24]], %[[VAL_30]] : f32
+// CHECK:           %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
+// CHECK:           %[[VAL_33:.*]] = mulf %[[VAL_24]], %[[VAL_32]] : f32
+// CHECK:           %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
+// CHECK:           %[[VAL_35:.*]] = mulf %[[VAL_24]], %[[VAL_34]] : f32
+// CHECK:           %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_11]] : f32
+// CHECK:           %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_36]] : f32
+// CHECK:           %[[VAL_38:.*]] = mulf %[[VAL_24]], %[[VAL_12]] : f32
+// CHECK:           %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
+// CHECK:           %[[VAL_40:.*]] = mulf %[[VAL_24]], %[[VAL_39]] : f32
+// CHECK:           %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
+// CHECK:           %[[VAL_42:.*]] = mulf %[[VAL_24]], %[[VAL_41]] : f32
+// CHECK:           %[[VAL_43:.*]] = addf %[[VAL_42]], %[[VAL_15]] : f32
+// CHECK:           %[[VAL_44:.*]] = divf %[[VAL_37]], %[[VAL_43]] : f32
+// CHECK:           %[[VAL_45:.*]] = select %[[VAL_20]], %[[VAL_16]], %[[VAL_44]] : f32
+// CHECK:           %[[VAL_46:.*]] = select %[[VAL_19]], %[[VAL_45]], %[[VAL_18]] : f32
+// CHECK:           %[[VAL_47:.*]] = fptrunc %[[VAL_46]] : f32 to f16
+// CHECK:           return %[[VAL_47]] : f16
+// CHECK:         }
+// CHECK:       }
+
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc
new file mode 100644
index 00000000000..9696db377da
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for lowering the tanh standard ops to an
+// approximation.
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
+
+namespace mlir {
+namespace xla {
+namespace {
+
+/// Emits the fast tanh approximation that is also used by XLA.
+static Value EmitTanhApproximation(Value input, Value abs_value, Location loc,
+                                   PatternRewriter &rewriter) {
+  // For small values of x we can approximate tanh(x) = x. For extremely small
+  // values of x (|x| < 1e-37), the rational approximation below would
+  // underflow and evaluate to tanh(x) = 0.
+  constexpr float kCanUseApprox = 0.0004;
+  Value can_use_approx =
+      rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
+  Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
+                                               abs_value, can_use_approx);
+
+  // Clamp the input to [-9, 9].
+  Value plus_nine =
+      rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(9.0));
+  Value smaller_than_nine =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, abs_value, plus_nine);
+  Value input_clamped = rewriter.create<SelectOp>(
+      loc, smaller_than_nine, input,
+      rewriter.create<CopySignOp>(loc, plus_nine, input));
+
+  static constexpr std::array<float, 7> numerator_coeffs{
+      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
+      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
+      4.89352455891786e-03f};
+
+  static constexpr std::array<float, 4> denominator_coeffs{
+      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
+      4.89352518554385e-03f};
+
+  Value input_squared =
+      rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
+  Value numerator = rewriter.create<ConstantOp>(
+      loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
+  for (int i = 1; i < numerator_coeffs.size(); i++) {
+    numerator = rewriter.create<AddFOp>(
+        loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
+        rewriter.create<ConstantOp>(
+            loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
+  }
+
+  numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
+
+  Value denominator = rewriter.create<ConstantOp>(
+      loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
+  for (int i = 1; i < denominator_coeffs.size(); i++) {
+    denominator = rewriter.create<AddFOp>(
+        loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
+        rewriter.create<ConstantOp>(
+            loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
+  }
+
+  Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
+
+  return rewriter.create<SelectOp>(loc, return_input, input, approx);
+}
+
+class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
+ public:
+  explicit ApproximateTanhLowering(MLIRContext *ctx)
+      : OpRewritePattern<TanhOp>(ctx, 100) {}
+
+  LogicalResult matchAndRewrite(TanhOp tanhOp,
+                                PatternRewriter &rewriter) const override {
+    Type operand_type = tanhOp.getType();
+
+    if (operand_type.isF64()) {
+      // Similar to XLA, do not rewrite f64 as precision might matter.
+      return failure();
+    }
+
+    Location loc = tanhOp.getLoc();
+    Value input = tanhOp.operand();
+    if (operand_type.isF16()) {
+      input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
+    }
+
+    // If we still do not have f32, fail.
+    if (!input.getType().isF32()) {
+      return failure();
+    }
+
+    // For |operand| > 20.0, tanh saturates; return +/-1 carrying the sign of
+    constexpr double kMaxValue = 20.0;
+    Value max_value =
+        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kMaxValue));
+    Value abs_value = rewriter.create<AbsFOp>(loc, input);
+
+    Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0));
+    Value one_with_sign = rewriter.create<CopySignOp>(loc, one, input);
+
+    Value smaller_than_twenty =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::ULT, abs_value, max_value);
+
+    // Otherwise, we use the approximation.
+    Value approx = EmitTanhApproximation(input, abs_value, loc, rewriter);
+
+    Value result = rewriter.create<SelectOp>(loc, smaller_than_twenty, approx,
+                                             one_with_sign);
+
+    // Truncate back if needed.
+    if (operand_type.isF16()) {
+      result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
+    }
+
+    rewriter.replaceOp(tanhOp, {result});
+    return success();
+  }
+};
+
+struct LegalizeTanhToApproximation
+    : public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
+  /// Replace all tanh ops in the function with the rational approximation.
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    PopulateTanhToApproximationPatterns(&getContext(), &patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+}  // anonymous namespace
+
+std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
+createLegalizeTanhToApproximationPass() {
+  return std::make_unique<LegalizeTanhToApproximation>();
+}
+
+void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
+                                         OwningRewritePatternList *patterns) {
+  patterns->insert<ApproximateTanhLowering>(context);
+}
+
+static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
+    "xla-legalize-tanh-to-approximation",
+    "Legalize tanh from standard dialect to an approximation");
+
+}  // namespace xla
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index a2af8124786..3db0bc3b474 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -115,6 +115,13 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
 
 }  // namespace xla_lhlo
+
+namespace xla {
+
+/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
+
+}  // namespace xla
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_
diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
index 59347198fe4..7303b87be75 100644
--- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
@@ -91,6 +91,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
 
 }  // namespace xla_chlo
 
+namespace xla {
+
+// Populates a pattern that translates the standard TanhOp to an approximation
+// that does not use intrinsics.
+void PopulateTanhToApproximationPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns);
+
+}  // namespace xla
 }  // namespace mlir
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_