From b0199879f09e4bd0068e627e1a43384dcff34983 Mon Sep 17 00:00:00 2001 From: Stephan Herhut <herhut@google.com> Date: Wed, 17 Jun 2020 01:26:51 -0700 Subject: [PATCH] Add lowering for TanhOp that uses an approximation instead of lowering to intrinsics. The same approximation is used by the XLA compiler. PiperOrigin-RevId: 316844625 Change-Id: I1a909bd063509491a6a58ae0acc3bfb919cb34d5 --- tensorflow/compiler/mlir/xla/BUILD | 19 ++ .../tests/legalize_tanh_to_approximation.mlir | 134 ++++++++++++++ .../legalize_tanh_to_approximation.cc | 167 ++++++++++++++++++ .../compiler/mlir/xla/transforms/passes.h | 7 + .../compiler/mlir/xla/transforms/rewriters.h | 8 + 5 files changed, 335 insertions(+) create mode 100644 tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir create mode 100644 tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 43458aab2d3..d089f80d571 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -515,6 +515,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_legalize_tanh_to_approximation", + srcs = ["transforms/legalize_tanh_to_approximation.cc"], + hdrs = [ + "transforms/passes.h", + "transforms/rewriters.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + gentbl( name = "xla_lower_complex_inc_gen", tbl_outs = [ @@ -946,6 +964,7 @@ cc_library( ":xla_hlo_fusion", ":xla_hlo_to_lhlo_with_xla", ":xla_legalize_control_flow", + ":xla_legalize_tanh_to_approximation", ":xla_legalize_tf", ":xla_legalize_tf_with_tf2xla", ":xla_legalize_to_linalg", diff --git a/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir new file mode 100644 index 00000000000..a8286c9b5a9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir @@ -0,0 +1,134 @@ +// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s + +func @tanh_f64(%arg0 : f64) -> f64 { + %res = tanh %arg0 : f64 + return %res : f64 +} + +// CHECK-LABEL: @tanh_f64 +// CHECK: tanh + +// ----- + +func @tanh_f32(%arg0 : f32) -> f32 { + %res = tanh %arg0 : f32 + return %res : f32 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// CHECK: module { + +// CHECK-LABEL: func @tanh_f32( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_16:.*]] = absf %[[VAL_0]] : f32 +// CHECK: %[[VAL_17:.*]] = copysign %[[VAL_2]], %[[VAL_0]] : f32 +// CHECK: %[[VAL_18:.*]] = cmpf "ult", %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "ule", %[[VAL_16]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_21:.*]] = copysign %[[VAL_4]], %[[VAL_0]] : f32 +// CHECK: %[[VAL_22:.*]] = select %[[VAL_20]], %[[VAL_0]], %[[VAL_21]] : f32 +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_23]], %[[VAL_27]] : f32 +// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_23]], %[[VAL_29]] : f32 +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32 +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_23]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_35:.*]] = addf %[[VAL_34]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_35]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_23]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_41:.*]] = mulf %[[VAL_23]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_42:.*]] = addf %[[VAL_41]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_43:.*]] = divf %[[VAL_36]], %[[VAL_42]] : f32 +// CHECK: %[[VAL_44:.*]] = select %[[VAL_19]], %[[VAL_0]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = select %[[VAL_18]], %[[VAL_44]], %[[VAL_17]] : f32 +// CHECK: return %[[VAL_45]] : f32 +// CHECK: } +// CHECK: } + +// ----- + +func @tanh_f16(%arg0 : f16) -> f16 { + %res = tanh %arg0 : f16 + return %res : f16 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// CHECK: module { + +// CHECK-LABEL: func @tanh_f16( +// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_16:.*]] = fpext %[[VAL_0]] : f16 to f32 +// CHECK: %[[VAL_17:.*]] = absf %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = copysign %[[VAL_2]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "ult", %[[VAL_17]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "olt", %[[VAL_17]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = cmpf "ule", %[[VAL_17]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_22:.*]] = copysign %[[VAL_4]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_23:.*]] = select %[[VAL_21]], %[[VAL_16]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_23]] : f32 +// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_24]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_24]], %[[VAL_26]] : f32 +// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_24]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_24]], %[[VAL_32]] : f32 +// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_24]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_36]] : f32 +// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_24]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_24]], %[[VAL_39]] : f32 +// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_42:.*]] = mulf %[[VAL_24]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_43:.*]] = addf %[[VAL_42]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_44:.*]] = divf %[[VAL_37]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = select %[[VAL_20]], %[[VAL_16]], %[[VAL_44]] : f32 +// CHECK: %[[VAL_46:.*]] = select %[[VAL_19]], %[[VAL_45]], %[[VAL_18]] : f32 +// CHECK: %[[VAL_47:.*]] = fptrunc %[[VAL_46]] : f32 to f16 +// CHECK: return %[[VAL_47]] : f16 +// CHECK: } +// CHECK: } + diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc new file mode 100644 index 00000000000..9696db377da --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering the tanh standard ops to an +// approximation. + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla { +namespace { + +/// Emits the fast tanh approximation that is also used by XLA. +static Value EmitTanhApproximation(Value input, Value abs_value, Location loc, + PatternRewriter &rewriter) { + // For small values of x, we can approximate tanh(x)=x. For extremely small + // values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + constexpr float kCanUseApprox = 0.0004; + Value can_use_approx = + rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + + // Clamp the input to [-9, 9]. + Value plus_nine = + rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(9.0)); + Value smaller_than_nine = + rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, abs_value, plus_nine); + Value input_clamped = rewriter.create<SelectOp>( + loc, smaller_than_nine, input, + rewriter.create<CopySignOp>(loc, plus_nine, input)); + + static constexpr std::array<float, 7> numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array<float, 4> denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create<MulFOp>(loc, input_clamped, input_clamped); + Value numerator = rewriter.create<ConstantOp>( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create<AddFOp>( + loc, rewriter.create<MulFOp>(loc, input_squared, numerator), + rewriter.create<ConstantOp>( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); + } + + numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator); + + Value denominator = rewriter.create<ConstantOp>( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create<AddFOp>( + loc, rewriter.create<MulFOp>(loc, input_squared, denominator), + rewriter.create<ConstantOp>( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create<DivFOp>(loc, numerator, denominator); + + return rewriter.create<SelectOp>(loc, return_input, input, approx); +} + +class ApproximateTanhLowering : public OpRewritePattern<TanhOp> { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : OpRewritePattern<TanhOp>(ctx, 100) {} + + LogicalResult matchAndRewrite(TanhOp tanhOp, + PatternRewriter &rewriter) const override { + Type operand_type = tanhOp.getType(); + + if (operand_type.isF64()) { + // Similar to XLA, do not rewrite f64 as precision might matter. + return failure(); + } + + Location loc = tanhOp.getLoc(); + Value input = tanhOp.operand(); + if (operand_type.isF16()) { + input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type()); + } + + // If we still do not have f32, fail. + if (!input.getType().isF32()) { + return failure(); + } + + // For |operand| > 20.0, we just return -1/1. + constexpr double kMaxValue = 20.0; + Value max_value = + rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kMaxValue)); + Value abs_value = rewriter.create<AbsFOp>(loc, input); + + Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)); + Value one_with_sign = rewriter.create<CopySignOp>(loc, one, input); + + Value smaller_than_twenty = + rewriter.create<CmpFOp>(loc, CmpFPredicate::ULT, abs_value, max_value); + + // Otherwise, we use the approximation. + Value approx = EmitTanhApproximation(input, abs_value, loc, rewriter); + + Value result = rewriter.create<SelectOp>(loc, smaller_than_twenty, approx, + one_with_sign); + + // Truncate back if needed. + if (operand_type.isF16()) { + result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type()); + } + + rewriter.replaceOp(tanhOp, {result}); + return success(); + } +}; + +struct LegalizeTanhToApproximation + : public PassWrapper<LegalizeTanhToApproximation, FunctionPass> { + /// Perform the lowering of standard dialect operations to approximations. + void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateTanhToApproximationPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // anonymous namespace + +std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> +createLegalizeTanhToApproximationPass() { + return std::make_unique<LegalizeTanhToApproximation>(); +} + +void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert<ApproximateTanhLowering>(context); +} + +static PassRegistration<LegalizeTanhToApproximation> legalize_pass( + "xla-legalize-tanh-to-approximation", + "Legalize tanh from standard dialect to an approximation"); + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index a2af8124786..3db0bc3b474 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -115,6 +115,13 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass(); } // namespace xla_lhlo + +namespace xla { + +/// Lowers the standard TanhOp to an approximation that does not use intrinsics. +std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass(); + +} // namespace xla } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index 59347198fe4..7303b87be75 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -91,6 +91,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, } // namespace xla_chlo +namespace xla { + +// Populates a pattern that translates the standard TanhOp to an approximation +// that does not use intrinsics. +void PopulateTanhToApproximationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_