Add a lowering for TanhOp that uses an approximation instead of lowering to intrinsics.
The same approximation is used by the XLA compiler.

PiperOrigin-RevId: 316844625
Change-Id: I1a909bd063509491a6a58ae0acc3bfb919cb34d5
This commit is contained in:
parent
de907d8746
commit
b0199879f0
@ -515,6 +515,24 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_legalize_tanh_to_approximation",
|
||||
srcs = ["transforms/legalize_tanh_to_approximation.cc"],
|
||||
hdrs = [
|
||||
"transforms/passes.h",
|
||||
"transforms/rewriters.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "xla_lower_complex_inc_gen",
|
||||
tbl_outs = [
|
||||
@ -946,6 +964,7 @@ cc_library(
|
||||
":xla_hlo_fusion",
|
||||
":xla_hlo_to_lhlo_with_xla",
|
||||
":xla_legalize_control_flow",
|
||||
":xla_legalize_tanh_to_approximation",
|
||||
":xla_legalize_tf",
|
||||
":xla_legalize_tf_with_tf2xla",
|
||||
":xla_legalize_to_linalg",
|
||||
|
@ -0,0 +1,134 @@
|
||||
// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
|
||||
|
||||
func @tanh_f64(%arg0 : f64) -> f64 {
|
||||
%res = tanh %arg0 : f64
|
||||
return %res : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tanh_f64
|
||||
// CHECK: tanh
|
||||
|
||||
// -----
|
||||
|
||||
func @tanh_f32(%arg0 : f32) -> f32 {
|
||||
%res = tanh %arg0 : f32
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
// CHECK: module {
|
||||
|
||||
// CHECK-LABEL: func @tanh_f32(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
|
||||
// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32
|
||||
// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32
|
||||
// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32
|
||||
// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32
|
||||
// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
|
||||
// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
|
||||
// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
|
||||
// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
|
||||
// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
|
||||
// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
|
||||
// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32
|
||||
// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
|
||||
// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
|
||||
// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32
|
||||
// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32
|
||||
// CHECK: %[[VAL_16:.*]] = absf %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = copysign %[[VAL_2]], %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = cmpf "ult", %[[VAL_16]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = cmpf "ule", %[[VAL_16]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = copysign %[[VAL_4]], %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = select %[[VAL_20]], %[[VAL_0]], %[[VAL_21]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_22]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_23]], %[[VAL_27]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_23]], %[[VAL_29]] : f32
|
||||
// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
|
||||
// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
|
||||
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
|
||||
// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_23]], %[[VAL_33]] : f32
|
||||
// CHECK: %[[VAL_35:.*]] = addf %[[VAL_34]], %[[VAL_11]] : f32
|
||||
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_35]] : f32
|
||||
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_12]] : f32
|
||||
// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_23]], %[[VAL_38]] : f32
|
||||
// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
|
||||
// CHECK: %[[VAL_41:.*]] = mulf %[[VAL_23]], %[[VAL_40]] : f32
|
||||
// CHECK: %[[VAL_42:.*]] = addf %[[VAL_41]], %[[VAL_15]] : f32
|
||||
// CHECK: %[[VAL_43:.*]] = divf %[[VAL_36]], %[[VAL_42]] : f32
|
||||
// CHECK: %[[VAL_44:.*]] = select %[[VAL_19]], %[[VAL_0]], %[[VAL_43]] : f32
|
||||
// CHECK: %[[VAL_45:.*]] = select %[[VAL_18]], %[[VAL_44]], %[[VAL_17]] : f32
|
||||
// CHECK: return %[[VAL_45]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
func @tanh_f16(%arg0 : f16) -> f16 {
|
||||
%res = tanh %arg0 : f16
|
||||
return %res : f16
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
// CHECK: module {
|
||||
|
||||
// CHECK-LABEL: func @tanh_f16(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
|
||||
// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32
|
||||
// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32
|
||||
// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32
|
||||
// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32
|
||||
// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32
|
||||
// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32
|
||||
// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32
|
||||
// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32
|
||||
// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32
|
||||
// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32
|
||||
// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32
|
||||
// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32
|
||||
// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32
|
||||
// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32
|
||||
// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32
|
||||
// CHECK: %[[VAL_16:.*]] = fpext %[[VAL_0]] : f16 to f32
|
||||
// CHECK: %[[VAL_17:.*]] = absf %[[VAL_16]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = copysign %[[VAL_2]], %[[VAL_16]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = cmpf "ult", %[[VAL_17]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = cmpf "olt", %[[VAL_17]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = cmpf "ule", %[[VAL_17]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = copysign %[[VAL_4]], %[[VAL_16]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = select %[[VAL_21]], %[[VAL_16]], %[[VAL_22]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_23]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_24]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_24]], %[[VAL_26]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
|
||||
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_24]], %[[VAL_30]] : f32
|
||||
// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
|
||||
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_24]], %[[VAL_32]] : f32
|
||||
// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
|
||||
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_24]], %[[VAL_34]] : f32
|
||||
// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_11]] : f32
|
||||
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_36]] : f32
|
||||
// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_24]], %[[VAL_12]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
|
||||
// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_24]], %[[VAL_39]] : f32
|
||||
// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
|
||||
// CHECK: %[[VAL_42:.*]] = mulf %[[VAL_24]], %[[VAL_41]] : f32
|
||||
// CHECK: %[[VAL_43:.*]] = addf %[[VAL_42]], %[[VAL_15]] : f32
|
||||
// CHECK: %[[VAL_44:.*]] = divf %[[VAL_37]], %[[VAL_43]] : f32
|
||||
// CHECK: %[[VAL_45:.*]] = select %[[VAL_20]], %[[VAL_16]], %[[VAL_44]] : f32
|
||||
// CHECK: %[[VAL_46:.*]] = select %[[VAL_19]], %[[VAL_45]], %[[VAL_18]] : f32
|
||||
// CHECK: %[[VAL_47:.*]] = fptrunc %[[VAL_46]] : f32 to f16
|
||||
// CHECK: return %[[VAL_47]] : f16
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
@ -0,0 +1,167 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements logic for lowering the tanh standard ops to an
|
||||
// approximation.
|
||||
|
||||
#include <array>
#include <cstddef>
#include <memory>

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
/// Emits the fast tanh approximation that is also used by XLA.
|
||||
static Value EmitTanhApproximation(Value input, Value abs_value, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
// For small values of x, we can approximate tanh(x)=x. For extremely small
|
||||
// values of x (|x| < 1e-37), the other approximation would evaluate
|
||||
// tanh(x) = 0.
|
||||
constexpr float kCanUseApprox = 0.0004;
|
||||
Value can_use_approx =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
|
||||
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
|
||||
abs_value, can_use_approx);
|
||||
|
||||
// Clamp the input to [-9, 9].
|
||||
Value plus_nine =
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(9.0));
|
||||
Value smaller_than_nine =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, abs_value, plus_nine);
|
||||
Value input_clamped = rewriter.create<SelectOp>(
|
||||
loc, smaller_than_nine, input,
|
||||
rewriter.create<CopySignOp>(loc, plus_nine, input));
|
||||
|
||||
static constexpr std::array<float, 7> numerator_coeffs{
|
||||
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
||||
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
|
||||
4.89352455891786e-03f};
|
||||
|
||||
static constexpr std::array<float, 4> denominator_coeffs{
|
||||
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
|
||||
4.89352518554385e-03f};
|
||||
|
||||
Value input_squared =
|
||||
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
|
||||
Value numerator = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
|
||||
for (int i = 1; i < numerator_coeffs.size(); i++) {
|
||||
numerator = rewriter.create<AddFOp>(
|
||||
loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
|
||||
rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
|
||||
}
|
||||
|
||||
numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
|
||||
|
||||
Value denominator = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
|
||||
for (int i = 1; i < denominator_coeffs.size(); i++) {
|
||||
denominator = rewriter.create<AddFOp>(
|
||||
loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
|
||||
rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
|
||||
}
|
||||
|
||||
Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
|
||||
|
||||
return rewriter.create<SelectOp>(loc, return_input, input, approx);
|
||||
}
|
||||
|
||||
/// Rewrite pattern replacing a standard-dialect `TanhOp` on f16 or f32 with
/// explicit arithmetic operations implementing an approximation of tanh.
/// f64 and every other type are left untouched (the pattern fails to match).
class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
 public:
  explicit ApproximateTanhLowering(MLIRContext *ctx)
      : OpRewritePattern<TanhOp>(ctx, 100) {}

  /// Replaces `tanhOp` by its approximation. f16 inputs are extended to f32
  /// for the computation and the result is truncated back to f16.
  LogicalResult matchAndRewrite(TanhOp tanhOp,
                                PatternRewriter &rewriter) const override {
    const Type result_type = tanhOp.getType();

    // Similar to XLA, do not rewrite f64 as precision might matter.
    if (result_type.isF64()) return failure();

    const Location loc = tanhOp.getLoc();
    const bool computes_in_wider_type = result_type.isF16();

    // The approximation itself is only emitted for f32; widen f16 first.
    Value x = tanhOp.operand();
    if (computes_in_wider_type) {
      x = rewriter.create<FPExtOp>(loc, x, rewriter.getF32Type());
    }

    // Whatever is not f32 at this point is not handled by this pattern.
    if (!x.getType().isF32()) return failure();

    // Small helper materializing an f32 constant at `loc`.
    auto f32_constant = [&](float value) -> Value {
      return rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(value));
    };

    // Inputs whose magnitude is not below 20.0 produce 1.0 with the sign of
    // the input instead of the approximation.
    Value saturation_bound = f32_constant(20.0);
    Value magnitude = rewriter.create<AbsFOp>(loc, x);

    Value unit = f32_constant(1.0);
    Value signed_unit = rewriter.create<CopySignOp>(loc, unit, x);

    Value below_bound = rewriter.create<CmpFOp>(loc, CmpFPredicate::ULT,
                                                magnitude, saturation_bound);

    // Everything else goes through the approximation.
    Value approximation = EmitTanhApproximation(x, magnitude, loc, rewriter);

    Value replacement = rewriter.create<SelectOp>(loc, below_bound,
                                                  approximation, signed_unit);

    // Narrow back to f16 when the op was widened above.
    if (computes_in_wider_type) {
      replacement =
          rewriter.create<FPTruncOp>(loc, replacement, rewriter.getF16Type());
    }

    rewriter.replaceOp(tanhOp, {replacement});
    return success();
  }
};
|
||||
|
||||
struct LegalizeTanhToApproximation
|
||||
: public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
|
||||
/// Perform the lowering of standard dialect operations to approximations.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
PopulateTanhToApproximationPatterns(&getContext(), &patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||
createLegalizeTanhToApproximationPass() {
|
||||
return std::make_unique<LegalizeTanhToApproximation>();
|
||||
}
|
||||
|
||||
/// Adds the pattern that rewrites the standard TanhOp into its approximation
/// to `patterns`. The pattern is created for `context`.
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  OwningRewritePatternList &pattern_list = *patterns;
  pattern_list.insert<ApproximateTanhLowering>(context);
}
|
||||
|
||||
// Registers the pass under its command-line flag, as used by the RUN line of
// the accompanying test (`-xla-legalize-tanh-to-approximation`).
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
    /*arg=*/"xla-legalize-tanh-to-approximation",
    /*description=*/"Legalize tanh from standard dialect to an approximation");
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
@ -115,6 +115,13 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||
|
||||
} // namespace xla_lhlo
|
||||
|
||||
namespace xla {
|
||||
|
||||
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_
|
||||
|
@ -91,6 +91,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
|
||||
} // namespace xla_chlo
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Populates a pattern that translates the standard TanhOp to an approximation
|
||||
// that does not use intrinsics.
|
||||
void PopulateTanhToApproximationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user