[MLIR][KernelGen] Add erf kernel for f32 arguments and missing lowerings
PiperOrigin-RevId: 352381016 Change-Id: Ib26ab051c37080fbd01d9f646e8244b2c65d2ec3
This commit is contained in:
parent
a2766ce758
commit
3fcba829d8
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#define _USE_MATH_DEFINES
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
|
||||
}
|
||||
};
|
||||
|
||||
Value MaterializePolynomialApproximation(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x,
|
||||
const std::vector<float> &coefficients) {
|
||||
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
|
||||
for (float c : coefficients) {
|
||||
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
|
||||
poly = rewriter.create<mhlo::AddOp>(
|
||||
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
|
||||
}
|
||||
return poly;
|
||||
}
|
||||
|
||||
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand) {
|
||||
const std::vector<float> kAlpha{
|
||||
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
|
||||
-1.60960333262415e-02f,
|
||||
};
|
||||
const std::vector<float> kBeta{
|
||||
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
|
||||
-7.37332916720468e-03f, -1.42647390514189e-02f,
|
||||
};
|
||||
|
||||
// Clamp argument between -4 and 4.
|
||||
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
|
||||
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
|
||||
Value x =
|
||||
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
|
||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||
|
||||
// Materialize polynomial approximation for x in [-4, 4].
|
||||
Value alpha_poly =
|
||||
MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
|
||||
Value beta_poly =
|
||||
MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
|
||||
Value mul_x_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
|
||||
return rewriter.create<mhlo::DivOp>(loc, mul_x_alpha_poly, beta_poly);
|
||||
}
|
||||
|
||||
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||
using OpConversionPattern<ErfOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ErfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type ty = getElementTypeOrSelf(op.getType());
|
||||
|
||||
// For now, we support only f32.
|
||||
if (!ty.isF32()) return failure();
|
||||
|
||||
ErfOp::Adaptor transformed(operands);
|
||||
rewriter.replaceOp(op, MaterializeErfApproximationF32(
|
||||
rewriter, op.getLoc(), transformed.operand()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
context, patterns, 5);
|
||||
|
||||
// Other patterns.
|
||||
patterns->insert<ConvertConstantLikeOp>(context);
|
||||
patterns->insert<ConvertConstantLikeOp, ConvertErfOp>(context);
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
|
||||
@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
|
||||
return %1 : tensor<3xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @erf_f32
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
|
||||
func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00>
|
||||
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00>
|
||||
// CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]])
|
||||
// CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]]
|
||||
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10>
|
||||
// CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]]
|
||||
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8>
|
||||
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]]
|
||||
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6>
|
||||
// CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]]
|
||||
// CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5>
|
||||
// CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]]
|
||||
// CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4>
|
||||
// CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]]
|
||||
// CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03>
|
||||
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]]
|
||||
// CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332>
|
||||
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]]
|
||||
// CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5>
|
||||
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
|
||||
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4>
|
||||
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
|
||||
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702>
|
||||
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
|
||||
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925>
|
||||
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
|
||||
// CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391>
|
||||
// CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]]
|
||||
// CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]]
|
||||
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
@ -16,8 +16,18 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER3(UnaryOp, CPU, "Erf", functor::erf, float, Eigen::half, double);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(UnaryOp, GPU, "Erf", functor::erf, float, Eigen::half, double);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER2(UnaryOp, GPU, "Erf", functor::erf, Eigen::half, double);
|
||||
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
|
||||
REGISTER(UnaryOp, GPU, "Erf", functor::erf, float);
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -55,6 +55,7 @@ filegroup(
|
||||
"gpu_op_conj.cc",
|
||||
"gpu_op_cos.cc",
|
||||
"gpu_op_cosh.cc",
|
||||
"gpu_op_erf.cc",
|
||||
"gpu_op_exp.cc",
|
||||
"gpu_op_expm1.cc",
|
||||
"gpu_op_floor.cc",
|
||||
@ -121,6 +122,7 @@ tf_kernel_library(
|
||||
":conj_kernels",
|
||||
":cos_kernels",
|
||||
":cosh_kernels",
|
||||
":erf_kernels",
|
||||
":exp_kernels",
|
||||
":expm1_kernels",
|
||||
":floor_kernels",
|
||||
@ -365,6 +367,15 @@ gen_kernel_library(
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "erf",
|
||||
tile_size = "256",
|
||||
types = [
|
||||
"f32",
|
||||
],
|
||||
unroll_factors = "4",
|
||||
)
|
||||
|
||||
gen_kernel_library(
|
||||
name = "imag",
|
||||
tile_size = "256",
|
||||
|
||||
@ -134,6 +134,8 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
int input_size = shape.num_elements();
|
||||
CHECK(lhs_input.size() <= input_size && rhs_input.size() <= input_size &&
|
||||
"expect input shape to hold all input values");
|
||||
auto repeated_lhs_input =
|
||||
test::RepeatInputToMatchShape(lhs_input, input_size);
|
||||
auto repeated_rhs_input =
|
||||
@ -165,6 +167,8 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape scalar_shape{};
|
||||
CHECK(other_input.size() <= other_shape.num_elements() &&
|
||||
"expect other input shape to hold all input values");
|
||||
auto repeated_other_input =
|
||||
test::RepeatInputToMatchShape(other_input, other_shape.num_elements());
|
||||
|
||||
|
||||
23
tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc
Normal file
23
tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float);
|
||||
|
||||
} // namespace tensorflow
|
||||
@ -18,7 +18,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace test {
|
||||
|
||||
TensorShape DefaultInputShape() { return TensorShape{3, 4}; }
|
||||
TensorShape DefaultInputShape() { return TensorShape{7, 13}; }
|
||||
|
||||
} // namespace test
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@ -148,7 +150,7 @@ template <typename T,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInputLessThanBitwidth() {
|
||||
auto max_shift = sizeof(T) * 8 - 1;
|
||||
absl::InlinedVector<T, 10> v(max_shift);
|
||||
absl::InlinedVector<T, 10> v;
|
||||
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
|
||||
return v;
|
||||
}
|
||||
@ -166,8 +168,9 @@ template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -0.7, -0.5, -0.3, -0.2, -0.1,
|
||||
-1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
|
||||
0.7, 0.9, 18.0});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@ -96,6 +96,7 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
BaselineOutT (*baseline_callback)(BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs and compute expected results.
|
||||
CHECK(input.size() <= shape.num_elements());
|
||||
auto repeated_input =
|
||||
test::RepeatInputToMatchShape(input, shape.num_elements());
|
||||
absl::InlinedVector<OutT, 10> expected_output =
|
||||
@ -249,6 +250,17 @@ GENERATE_DEFAULT_TEST(Cosh, DT_FLOAT, DT_FLOAT, std::cosh,
|
||||
GENERATE_DEFAULT_TEST(Cosh, DT_DOUBLE, DT_DOUBLE, std::cosh,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Erf`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Erf, DT_FLOAT, DT_FLOAT, std::erf,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Erf, DT_DOUBLE, DT_DOUBLE, std::erf,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Erf, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::erf,
|
||||
test::GpuOpsTestConfig())
|
||||
|
||||
/// Test `tf.Exp`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user