Add lowering for TanhOp that uses an approximation instead of lowering to intrinsics.

The same approximation is used by the XLA compiler.

PiperOrigin-RevId: 316844625
Change-Id: I1a909bd063509491a6a58ae0acc3bfb919cb34d5
This commit is contained in:
Stephan Herhut 2020-06-17 01:26:51 -07:00 committed by TensorFlower Gardener
parent de907d8746
commit b0199879f0
5 changed files with 335 additions and 0 deletions

View File

@ -515,6 +515,24 @@ cc_library(
alwayslink = 1,
)
# Expansion of the standard tanh op into the rational approximation that the
# XLA compiler also uses, so the lowering does not rely on a tanh intrinsic.
cc_library(
    name = "xla_legalize_tanh_to_approximation",
    srcs = ["transforms/legalize_tanh_to_approximation.cc"],
    hdrs = [
        "transforms/passes.h",
        "transforms/rewriters.h",
    ],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
    # Keep the static pass registration alive even when nothing references
    # this library's symbols directly.
    alwayslink = 1,
)
gentbl(
name = "xla_lower_complex_inc_gen",
tbl_outs = [
@ -946,6 +964,7 @@ cc_library(
":xla_hlo_fusion",
":xla_hlo_to_lhlo_with_xla",
":xla_legalize_control_flow",
":xla_legalize_tanh_to_approximation",
":xla_legalize_tf",
":xla_legalize_tf_with_tf2xla",
":xla_legalize_to_linalg",

View File

@ -0,0 +1,134 @@
// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
// f64 is deliberately left untouched by the pass (see the pattern's f64
// bail-out): the approximation is only tuned for single precision.
func @tanh_f64(%arg0 : f64) -> f64 {
  %res = tanh %arg0 : f64
  return %res : f64
}

// CHECK-LABEL: @tanh_f64
// CHECK: tanh

// -----
// f32 is fully expanded: constant pool, clamp of the operand to [-9, 9],
// two Horner-evaluated polynomials in x^2, their quotient, and selects for
// the tiny-input (return x) and |x| >= 20 (return copysign(1, x)) cases.
func @tanh_f32(%arg0 : f32) -> f32 {
  %res = tanh %arg0 : f32
  return %res : f32
}

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: module {
// CHECK-LABEL: func @tanh_f32(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32
// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32
// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_16:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_17:.*]] = copysign %[[VAL_2]], %[[VAL_0]] : f32
// CHECK: %[[VAL_18:.*]] = cmpf "ult", %[[VAL_16]], %[[VAL_1]] : f32
// CHECK: %[[VAL_19:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_3]] : f32
// CHECK: %[[VAL_20:.*]] = cmpf "ule", %[[VAL_16]], %[[VAL_4]] : f32
// CHECK: %[[VAL_21:.*]] = copysign %[[VAL_4]], %[[VAL_0]] : f32
// CHECK: %[[VAL_22:.*]] = select %[[VAL_20]], %[[VAL_0]], %[[VAL_21]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_22]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_5]] : f32
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32
// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_23]], %[[VAL_27]] : f32
// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_23]], %[[VAL_29]] : f32
// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_23]], %[[VAL_33]] : f32
// CHECK: %[[VAL_35:.*]] = addf %[[VAL_34]], %[[VAL_11]] : f32
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_35]] : f32
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_12]] : f32
// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_23]], %[[VAL_38]] : f32
// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
// CHECK: %[[VAL_41:.*]] = mulf %[[VAL_23]], %[[VAL_40]] : f32
// CHECK: %[[VAL_42:.*]] = addf %[[VAL_41]], %[[VAL_15]] : f32
// CHECK: %[[VAL_43:.*]] = divf %[[VAL_36]], %[[VAL_42]] : f32
// CHECK: %[[VAL_44:.*]] = select %[[VAL_19]], %[[VAL_0]], %[[VAL_43]] : f32
// CHECK: %[[VAL_45:.*]] = select %[[VAL_18]], %[[VAL_44]], %[[VAL_17]] : f32
// CHECK: return %[[VAL_45]] : f32
// CHECK: }
// CHECK: }

// -----
// f16 reuses the f32 expansion: the operand is extended with fpext, the same
// approximation is emitted in f32, and the result is truncated with fptrunc.
func @tanh_f16(%arg0 : f16) -> f16 {
  %res = tanh %arg0 : f16
  return %res : f16
}

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: module {
// CHECK-LABEL: func @tanh_f16(
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32
// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32
// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_16:.*]] = fpext %[[VAL_0]] : f16 to f32
// CHECK: %[[VAL_17:.*]] = absf %[[VAL_16]] : f32
// CHECK: %[[VAL_18:.*]] = copysign %[[VAL_2]], %[[VAL_16]] : f32
// CHECK: %[[VAL_19:.*]] = cmpf "ult", %[[VAL_17]], %[[VAL_1]] : f32
// CHECK: %[[VAL_20:.*]] = cmpf "olt", %[[VAL_17]], %[[VAL_3]] : f32
// CHECK: %[[VAL_21:.*]] = cmpf "ule", %[[VAL_17]], %[[VAL_4]] : f32
// CHECK: %[[VAL_22:.*]] = copysign %[[VAL_4]], %[[VAL_16]] : f32
// CHECK: %[[VAL_23:.*]] = select %[[VAL_21]], %[[VAL_16]], %[[VAL_22]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_24]], %[[VAL_5]] : f32
// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_24]], %[[VAL_26]] : f32
// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f32
// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_24]], %[[VAL_30]] : f32
// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_24]], %[[VAL_32]] : f32
// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_24]], %[[VAL_34]] : f32
// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_11]] : f32
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_36]] : f32
// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_24]], %[[VAL_12]] : f32
// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_24]], %[[VAL_39]] : f32
// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
// CHECK: %[[VAL_42:.*]] = mulf %[[VAL_24]], %[[VAL_41]] : f32
// CHECK: %[[VAL_43:.*]] = addf %[[VAL_42]], %[[VAL_15]] : f32
// CHECK: %[[VAL_44:.*]] = divf %[[VAL_37]], %[[VAL_43]] : f32
// CHECK: %[[VAL_45:.*]] = select %[[VAL_20]], %[[VAL_16]], %[[VAL_44]] : f32
// CHECK: %[[VAL_46:.*]] = select %[[VAL_19]], %[[VAL_45]], %[[VAL_18]] : f32
// CHECK: %[[VAL_47:.*]] = fptrunc %[[VAL_46]] : f32 to f16
// CHECK: return %[[VAL_47]] : f16
// CHECK: }
// CHECK: }

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering the tanh standard ops to an
// approximation.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
namespace mlir {
namespace xla {
namespace {
/// Emits the fast tanh approximation that is also used by XLA.
///
/// Builds, via `rewriter` at `loc`, a rational (polynomial / polynomial)
/// approximation of tanh. `input` must be an f32 value and `abs_value` its
/// absolute value; the caller is responsible for the |x| >= 20 saturation
/// case. Returns the SSA value holding the approximated result.
///
/// Fix: the Horner loops previously compared a signed `int` index against
/// `std::array::size()` (a `size_t`), triggering -Wsign-compare; the indices
/// are now `size_t`. The emitted IR (and its order) is unchanged.
static Value EmitTanhApproximation(Value input, Value abs_value, Location loc,
                                   PatternRewriter &rewriter) {
  // For small values of x, we can approximate tanh(x)=x. For extremely small
  // values of x (|x| < 1e-37), the other approximation would evaluate
  // tanh(x) = 0.
  constexpr float kCanUseApprox = 0.0004;
  Value can_use_approx =
      rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
  Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
                                               abs_value, can_use_approx);

  // Clamp the input to [-9, 9]. ULE is also true for NaN, so a NaN input is
  // passed through unclamped and propagates through the arithmetic below.
  Value plus_nine =
      rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(9.0));
  Value smaller_than_nine =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, abs_value, plus_nine);
  Value input_clamped = rewriter.create<SelectOp>(
      loc, smaller_than_nine, input,
      rewriter.create<CopySignOp>(loc, plus_nine, input));

  // Coefficients of the numerator/denominator polynomials in x^2, highest
  // degree first. These are the same constants XLA's tanh expansion uses.
  static constexpr std::array<float, 7> numerator_coeffs{
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};

  static constexpr std::array<float, 4> denominator_coeffs{
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};

  Value input_squared =
      rewriter.create<MulFOp>(loc, input_clamped, input_clamped);

  // Evaluate the numerator polynomial in x^2 using Horner's scheme.
  Value numerator = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
  for (size_t i = 1; i < numerator_coeffs.size(); i++) {
    numerator = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
        rewriter.create<ConstantOp>(
            loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
  }

  // The numerator polynomial is odd in x, hence one extra multiply by x.
  numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);

  // Evaluate the denominator polynomial in x^2, also via Horner's scheme.
  Value denominator = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
  for (size_t i = 1; i < denominator_coeffs.size(); i++) {
    denominator = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
        rewriter.create<ConstantOp>(
            loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
  }

  Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);

  // For tiny |x|, return x itself so the quotient cannot underflow to 0.
  return rewriter.create<SelectOp>(loc, return_input, input, approx);
}
/// Rewrites the standard TanhOp into the rational approximation emitted by
/// EmitTanhApproximation instead of keeping the tanh intrinsic.
///
/// The approximation is computed in f32: f16 operands are extended before
/// and truncated after; f64 operands are deliberately not rewritten; any
/// other operand type makes the pattern fail to match.
class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
 public:
  explicit ApproximateTanhLowering(MLIRContext *ctx)
      // Benefit 100 so this pattern is preferred over default lowerings.
      : OpRewritePattern<TanhOp>(ctx, 100) {}

  LogicalResult matchAndRewrite(TanhOp tanhOp,
                                PatternRewriter &rewriter) const override {
    Type operand_type = tanhOp.getType();

    if (operand_type.isF64()) {
      // Similar to XLA, do not rewrite f64 as precision might matter.
      return failure();
    }

    Location loc = tanhOp.getLoc();
    Value input = tanhOp.operand();
    if (operand_type.isF16()) {
      // Widen to f32 for the computation; the result is truncated back below.
      input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
    }

    // If we still do not have f32, fail.
    if (!input.getType().isF32()) {
      return failure();
    }

    // For |operand| > 20.0, we just return -1/1.
    constexpr double kMaxValue = 20.0;
    Value max_value =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kMaxValue));
    Value abs_value = rewriter.create<AbsFOp>(loc, input);

    Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0));
    Value one_with_sign = rewriter.create<CopySignOp>(loc, one, input);

    // ULT is also true for NaN, so a NaN input takes the approximation path
    // and propagates through its arithmetic rather than saturating to +-1.
    Value smaller_than_twenty =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::ULT, abs_value, max_value);

    // Otherwise, we use the approximation.
    Value approx = EmitTanhApproximation(input, abs_value, loc, rewriter);

    Value result = rewriter.create<SelectOp>(loc, smaller_than_twenty, approx,
                                             one_with_sign);

    // Truncate back if needed.
    if (operand_type.isF16()) {
      result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
    }

    rewriter.replaceOp(tanhOp, {result});
    return success();
  }
};
/// Function pass that greedily applies the tanh-approximation pattern to
/// every tanh op in the function body.
struct LegalizeTanhToApproximation
    : public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
  /// Perform the lowering of standard dialect operations to approximations.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateTanhToApproximationPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
} // anonymous namespace
/// Factory for the pass that lowers the standard TanhOp to the XLA-style
/// rational approximation. Ownership is transferred to the caller.
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTanhToApproximationPass() {
  auto pass = std::make_unique<LegalizeTanhToApproximation>();
  return pass;
}
/// Adds the tanh-approximation rewrite pattern to `patterns`, for use by
/// callers that assemble their own pattern sets instead of running the pass.
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  patterns->insert<ApproximateTanhLowering>(context);
}
// Registers the pass so tools such as xla-opt can run it via
// -xla-legalize-tanh-to-approximation (see the lit test in this change).
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
    "xla-legalize-tanh-to-approximation",
    "Legalize tanh from standard dialect to an approximation");
} // namespace xla
} // namespace mlir

View File

@ -115,6 +115,13 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo
namespace xla {
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_

View File

@ -91,6 +91,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
} // namespace xla_chlo
namespace xla {
// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_