[MLIR][KernelGen] Add erf kernel for f32 arguments and missing lowerings

PiperOrigin-RevId: 352381016
Change-Id: Ib26ab051c37080fbd01d9f646e8244b2c65d2ec3
A. Unique TensorFlower 2021-01-18 03:34:19 -08:00 committed by TensorFlower Gardener
parent a2766ce758
commit 3fcba829d8
9 changed files with 179 additions and 7 deletions

View File

@@ -20,6 +20,7 @@ limitations under the License.
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
#include <vector>
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
}
};
Value MaterializePolynomialApproximation(
ConversionPatternRewriter &rewriter, Location loc, Value x,
const std::vector<float> &coefficients) {
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
for (float c : coefficients) {
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
poly = rewriter.create<mhlo::AddOp>(
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
}
return poly;
}
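The loop above is Horner's scheme: starting from zero, each step multiplies the running value by x and adds the next coefficient. A minimal scalar C++ sketch of the same computation (the helper name EvaluatePolynomial is ours, not part of the patch):

#include <vector>

// Horner's scheme, mirroring the alternating MulOp/AddOp chain above:
// poly = (((0 * x + c0) * x + c1) * x + ...) * x + cn
float EvaluatePolynomial(float x, const std::vector<float> &coefficients) {
  float poly = 0.0f;
  for (float c : coefficients) poly = poly * x + c;
  return poly;
}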
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) {
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-1.60960333262415e-02f,
};
const std::vector<float> kBeta{
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-7.37332916720468e-03f, -1.42647390514189e-02f,
};
// Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
Value x =
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4].
Value alpha_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
Value beta_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
Value mul_x_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
return rewriter.create<mhlo::DivOp>(loc, mul_x_alpha_poly, beta_poly);
}
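This materializes the rational approximation erf(x) ≈ x * P(x²) / Q(x²), with both polynomials evaluated in x², after clamping the argument to [-4, 4], where erf is already saturated at ±1 to within f32 precision. A scalar sketch, reusing the hypothetical EvaluatePolynomial helper from above:

#include <algorithm>
#include <vector>

// Scalar reference for the f32 approximation: clamp, square, evaluate
// the two polynomials in x^2, then combine as x * P(x^2) / Q(x^2).
float ErfApproxF32(float x, const std::vector<float> &alpha,
                   const std::vector<float> &beta) {
  x = std::max(-4.0f, std::min(4.0f, x));
  const float x_sq = x * x;
  return x * EvaluatePolynomial(x_sq, alpha) /
         EvaluatePolynomial(x_sq, beta);
}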
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type ty = getElementTypeOrSelf(op.getType());
// For now, we support only f32.
if (!ty.isF32()) return failure();
ErfOp::Adaptor transformed(operands);
rewriter.replaceOp(op, MaterializeErfApproximationF32(
rewriter, op.getLoc(), transformed.operand()));
return success();
}
};
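PopulateLegalizeChloToHloPatterns below now registers this pattern alongside ConvertConstantLikeOp.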
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
context, patterns, 5);
// Other patterns.
patterns->insert<ConvertConstantLikeOp>(context);
patterns->insert<ConvertConstantLikeOp, ConvertErfOp>(context);
}
} // namespace chlo

View File

@@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
return %1 : tensor<3xcomplex<f32>>
}
// CHECK-LABEL: @erf_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00>
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00>
// CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]])
// CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]]
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]]
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10>
// CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]]
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]]
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8>
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]]
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]]
// CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6>
// CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]]
// CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]]
// CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5>
// CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]]
// CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]]
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4>
// CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]]
// CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]]
// CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03>
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]]
// CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]]
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332>
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]]
// CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]]
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5>
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]]
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4>
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]]
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702>
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]]
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925>
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
// CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]]
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391>
// CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]]
// CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]]
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]]
// CHECK: return %[[RESULT]]
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}
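Note how the CHECK lines trace the two Horner chains: seven multiply/add steps for the alpha polynomial (%[[TMP_4]] through %[[TMP_25]]) and five for the beta polynomial (%[[TMP_26]] through %[[TMP_41]]), followed by the final multiply by the clamped argument and the divide.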

View File

@@ -16,8 +16,18 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "Erf", functor::erf, float, Eigen::half, double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Erf", functor::erf, float, Eigen::half, double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, double);
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER(UnaryOp, GPU, "Erf", functor::erf, float);
#endif
#endif
} // namespace tensorflow
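With this change, the hand-written Eigen::half and double GPU registrations remain unconditional within the CUDA/ROCm block, while the f32 registration is compiled in only unless both MLIR_GENERATED_GPU_KERNELS_ENABLED and MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED are defined; in that configuration the MLIR-generated f32 kernel (gpu_op_erf.cc, below) takes over.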

View File

@@ -55,6 +55,7 @@ filegroup(
"gpu_op_conj.cc",
"gpu_op_cos.cc",
"gpu_op_cosh.cc",
"gpu_op_erf.cc",
"gpu_op_exp.cc",
"gpu_op_expm1.cc",
"gpu_op_floor.cc",
@@ -121,6 +122,7 @@ tf_kernel_library(
":conj_kernels",
":cos_kernels",
":cosh_kernels",
":erf_kernels",
":exp_kernels",
":expm1_kernels",
":floor_kernels",
@@ -365,6 +367,15 @@ gen_kernel_library(
unroll_factors = "4",
)
gen_kernel_library(
name = "erf",
tile_size = "256",
types = [
"f32",
],
unroll_factors = "4",
)
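This declares the generated erf kernel for f32 only, mirroring the neighboring unary kernels in this file with the same tile_size = "256" and unroll_factors = "4".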
gen_kernel_library(
name = "imag",
tile_size = "256",

View File

@@ -134,6 +134,8 @@ class GpuBinaryOpTest : public OpsTestBase {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
int input_size = shape.num_elements();
CHECK(lhs_input.size() <= input_size && rhs_input.size() <= input_size &&
"expect input shape to hold all input values");
auto repeated_lhs_input =
test::RepeatInputToMatchShape(lhs_input, input_size);
auto repeated_rhs_input =
@@ -165,6 +167,8 @@ class GpuBinaryOpTest : public OpsTestBase {
const test::GpuOpsTestConfig& config) {
// Prepare inputs.
TensorShape scalar_shape{};
CHECK(other_input.size() <= other_shape.num_elements() &&
"expect other input shape to hold all input values");
auto repeated_other_input =
test::RepeatInputToMatchShape(other_input, other_shape.num_elements());

View File

@@ -0,0 +1,23 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float);
} // namespace tensorflow
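This new file instantiates the MLIR-generated kernel: the GENERATE_AND_REGISTER_UNARY_KERNEL invocation ties the generated f32 erf code to the TensorFlow op registry under DT_FLOAT, standing in for the hand-written f32 registration guarded above.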

View File

@@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
namespace test {
TensorShape DefaultInputShape() { return TensorShape{3, 4}; }
TensorShape DefaultInputShape() { return TensorShape{7, 13}; }
} // namespace test
} // namespace tensorflow
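The default shape grows from 12 to 91 elements, likely because DefaultInput (see below) now returns 18 values, which no longer fit into a 3x4 tensor; the new CHECKs added to the binary- and unary-op tests guard exactly this invariant.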

View File

@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
#include <iostream>
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
@@ -148,7 +150,7 @@ template <typename T,
bool> = true>
absl::InlinedVector<T, 10> DefaultInputLessThanBitwidth() {
auto max_shift = sizeof(T) * 8 - 1;
absl::InlinedVector<T, 10> v(max_shift);
absl::InlinedVector<T, 10> v;
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
return v;
}
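This fixes a subtle bug: absl::InlinedVector<T, 10> v(max_shift) value-initializes max_shift zero elements, so the push_back loop then grew the vector to 2 * max_shift entries instead of filling it. A minimal sketch of the difference (hypothetical snippet, not from the patch):

#include "absl/container/inlined_vector.h"

void InlinedVectorSizeSketch() {
  absl::InlinedVector<int, 10> sized(3);  // {0, 0, 0}
  sized.push_back(1);                     // {0, 0, 0, 1}: size 4
  absl::InlinedVector<int, 10> empty;     // {}
  empty.push_back(1);                     // {1}: size 1, as intended
}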
@@ -166,8 +168,9 @@ template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInput() {
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
return InputAsVector<T, double>({-18.0, -9.0, -0.7, -0.5, -0.3, -0.2, -0.1,
-1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
0.7, 0.9, 18.0});
}
template <typename T,

View File

@@ -96,6 +96,7 @@ class GpuUnaryOpTest : public OpsTestBase {
BaselineOutT (*baseline_callback)(BaselineT),
const test::GpuOpsTestConfig& config) {
// Prepare inputs and compute expected results.
CHECK(input.size() <= shape.num_elements());
auto repeated_input =
test::RepeatInputToMatchShape(input, shape.num_elements());
absl::InlinedVector<OutT, 10> expected_output =
@@ -249,6 +250,17 @@ GENERATE_DEFAULT_TEST(Cosh, DT_FLOAT, DT_FLOAT, std::cosh,
GENERATE_DEFAULT_TEST(Cosh, DT_DOUBLE, DT_DOUBLE, std::cosh,
test::GpuOpsTestConfig())
/// Test `tf.Erf`.
GENERATE_DEFAULT_TEST(Erf, DT_FLOAT, DT_FLOAT, std::erf,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST(Erf, DT_DOUBLE, DT_DOUBLE, std::erf,
test::GpuOpsTestConfig())
GENERATE_DEFAULT_TEST_2(Erf, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::erf,
test::GpuOpsTestConfig())
/// Test `tf.Exp`.
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,