Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/onednn-centos7

Abolfazl Shahbazi 2021-02-02 09:47:48 -08:00
commit 5381666f28
187 changed files with 1972 additions and 703 deletions

View File

@ -555,6 +555,15 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
let hasCanonicalizer = 1;
}
def HLOClient_DigammaOp : HLOClient_UnaryElementwiseOp<"digamma",
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Digamma function";
let description = [{
Returns `Digamma(operand)` element-wise.
}];
}
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Erfc operator";

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
@ -588,6 +589,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
ArrayRef<Value> args,
OpBuilder* b) {
lmhlo::PowOp::Adaptor adaptor(args);
auto lb = ImplicitLocOpBuilder(loc, *b);
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
@ -597,53 +599,66 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now");
// There is no powi, so lower to a simple product.
Value neg_one =
b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1));
Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0));
Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2));
// Exponentiation by squaring:
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1));
Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0));
Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1));
Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2));
Value step = lb.create<ConstantIndexOp>(1);
Value lowerBound = lb.create<ConstantIndexOp>(0);
// Everything else would overflow for any exponent > 1, as 2^64
// is the largest possible exponent for a 64-bit integer, and
// 64 = 1 << 6.
Value upperBound = lb.create<ConstantIndexOp>(6);
auto original_base = adaptor.lhs();
auto original_exponent = adaptor.rhs();
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
Value upperBound =
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
Value step = b->create<ConstantIndexOp>(loc, 1);
Value for_result =
b->create<scf::ForOp>(
loc, lowerBound, upperBound, step, llvm::makeArrayRef(one),
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
Value prod =
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
b.create<scf::YieldOp>(l, prod);
})
Value accum =
lb.create<scf::ForOp>(
lowerBound, upperBound, step,
SmallVector<Value>({one, original_base, original_exponent}),
[&](OpBuilder& b, Location, Value v, ValueRange iters) {
Value accum = iters[0];
Value base = iters[1];
Value exponent = iters[2];
Value condition = b.create<CmpIOp>(
loc, CmpIPredicate::eq,
b.create<::mlir::AndOp>(loc, exponent, one), one);
Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base);
accum =
b.create<::mlir::SelectOp>(loc, condition, multiplied, accum);
base = b.create<::mlir::MulIOp>(loc, base, base);
exponent =
b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one);
b.create<scf::YieldOp>(
loc, SmallVector<Value>({accum, base, exponent}));
})
.getResult(0);
Value rhs_is_even =
b->create<CmpIOp>(loc, CmpIPredicate::eq,
b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero);
Value rhs_is_even = lb.create<CmpIOp>(
CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero);
Value rhs_is_negative =
b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero);
Value lhs_is_one =
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one);
lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero);
Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one);
Value lhs_is_neg_one =
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one);
lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one);
// The for_result is correct when the rhs is non-negative. When rhs is
// The accum is correct when the rhs is non-negative. When rhs is
// negative, we return 0 for integers, with the exception of lhs values of 1
// and -1, which have integer results for negative exponents. Specifically, the
// calculation is the following:
//
// - Return for_result if the rhs is not negative.
// - Return accum if the rhs is not negative.
// - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
// - Return 1 if lhs is 1.
// - Else return 0.
Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero);
Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>(
loc, lhs_is_neg_one,
b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one),
Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero);
Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>(
lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one),
if_lhs_is_one);
return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one,
for_result);
return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum);
}
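For reference, a minimal standalone C++ sketch of the square-and-multiply scheme lowered above, assuming plain int64_t values (the function name and the unsigned internal arithmetic are illustrative): it mirrors the three loop-carried values (accumulator, base, exponent), the fixed six iterations, and the special handling of negative exponents.

#include <cstdint>

// Exponentiation by squaring over the low 6 bits of the exponent; unsigned
// arithmetic sidesteps signed-overflow UB and, like the lowering, simply
// ignores overflow.
int64_t IntegerPow(int64_t base, int64_t exponent) {
  uint64_t accum = 1;
  uint64_t b = static_cast<uint64_t>(base);
  uint64_t e = static_cast<uint64_t>(exponent);
  for (int i = 0; i < 6; ++i) {
    if (e & 1) accum *= b;  // fold in the current exponent bit
    b *= b;                 // square the base
    e >>= 1;                // advance to the next exponent bit
  }
  if (exponent >= 0) return static_cast<int64_t>(accum);
  // Negative exponents: only bases 1 and -1 have integer results.
  if (base == 1) return 1;
  if (base == -1) return (exponent % 2 == 0) ? 1 : -1;
  return 0;
}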
template <>

View File

@ -1310,6 +1310,40 @@ class DynamicReshapeOpNotActuallyDynamic
}
};
// Canonicalizes
// %0 = some_op(%tensor)
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
// (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
// ... uses of %1.
//
// into
//
// ... uses of %0.
// This canonicalization is only correct if the input is correct!
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
// errors in input, and still allows us to get rid of redundant reshapes.
class RemoveRedundantRank1DynamicReshape
: public OpRewritePattern<DynamicReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
auto type = op.result().getType().dyn_cast<RankedTensorType>();
if (!type || type.getRank() != 1 || type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "requires rank 1 shape tensor with dynamic dimension");
}
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_type || operand_type.getRank() != 1 ||
operand_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "requires rank 1 shape tensor with dynamic dimension");
}
rewriter.replaceOp(op, {op.operand()});
return success();
}
};
// Canonicalizes
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
// %1 = same_operands_and_result_shape_op(%tensor)
@ -1354,6 +1388,7 @@ void DynamicReshapeOp::getCanonicalizationPatterns(
DynamicReshapeOpSameShapeOpResult,
RemoveRedundantDynamicBroadcast,
RemoveRedundantDynamicReshape,
RemoveRedundantRank1DynamicReshape,
ShapeOfDynamicReshape
>(context);
// clang-format on

View File

@ -625,6 +625,119 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
lgamma);
}
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
Value x) {
// If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1, x);
Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
Value z =
rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value zero = getConstantLike(rewriter, loc, 0.0, x);
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
Value a_prime = zero;
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
a_prime = rewriter.create<mhlo::SubOp>(
loc, a_prime,
rewriter.create<mhlo::DivOp>(
loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
a = rewriter.create<mhlo::AddOp>(
loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
}
// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczos_plus_half =
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
Value log_term =
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1p_term = rewriter.create<mhlo::Log1pOp>(
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
// Materialize the final result (modulo reflection) as
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
Value digamma = rewriter.create<mhlo::SubOp>(
loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
lanczos_gamma_div_t);
// We need to be careful how we compute cot(pi * input) below: For
// near-integral arguments, pi * input can lose precision.
//
// Input is already known to be less than 0.5 (otherwise we don't have to
// reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
// increase precision of pi * x and the resulting cotangent.
Value reduced_x = rewriter.create<mhlo::AddOp>(
loc, x,
rewriter.create<mhlo::AbsOp>(
loc, rewriter.create<mhlo::FloorOp>(
loc, rewriter.create<mhlo::AddOp>(
loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
// Materialize reflection for inputs less than 0.5 as
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
Value pi = getConstantLike(rewriter, loc, M_PI, x);
Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
Value reflection = rewriter.create<mhlo::SubOp>(
loc, digamma,
rewriter.create<mhlo::DivOp>(
loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));
// Select whether or not to rely on the reflection.
digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
digamma);
// Digamma has poles at negative integers and zero; return nan for those.
const StringAttr kLE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
const StringAttr kEQ = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
Value is_int = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
return rewriter.create<mhlo::SelectOp>(
loc, is_pole,
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
x),
digamma);
}
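As a cross-check, a scalar double-precision sketch of the same computation (a standalone example rather than rewriter code; it assumes g = 7 and reuses the base coefficient and the eight Lanczos coefficients that also appear in the f64 test below):

#include <cmath>
#include <limits>

// Digamma via the Lanczos approximation (g = 7), Euler's reflection for
// x < 0.5, and NaN at the poles (zero and the negative integers).
double Digamma(double x) {
  static const double kLanczosCoefficients[] = {
      676.5203681218851,     -1259.1392167224028,   771.32342877765313,
      -176.61502916214059,   12.507343278686905,    -0.13857109526572012,
      9.9843695780195716e-6, 1.5056327351493116e-7};
  const double kLanczosGamma = 7.0;
  const double kBaseLanczosCoeff = 0.99999999999980993;

  const bool need_to_reflect = x < 0.5;
  const double z = need_to_reflect ? -x : x - 1.0;

  // a(z) and a'(z) from the comment above.
  double a = kBaseLanczosCoeff;
  double a_prime = 0.0;
  for (int k = 0; k < 8; ++k) {
    const double z_term = z + k + 1;
    a_prime -= kLanczosCoefficients[k] / (z_term * z_term);
    a += kLanczosCoefficients[k] / z_term;
  }

  // log(t) computed as log(g + 1/2) + log1p(z / (g + 1/2)) for accuracy.
  const double t = kLanczosGamma + 0.5 + z;
  const double log_t =
      std::log(kLanczosGamma + 0.5) + std::log1p(z / (kLanczosGamma + 0.5));
  double digamma = log_t + a_prime / a - kLanczosGamma / t;

  if (need_to_reflect) {
    // Shift x into [-0.5, 0.5] before forming pi * x, as in the comment above.
    const double reduced_x = x + std::abs(std::floor(x + 0.5));
    digamma -= M_PI * std::cos(M_PI * reduced_x) / std::sin(M_PI * reduced_x);
  }
  if (x <= 0.0 && x == std::floor(x)) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  return digamma;
}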
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
@ -639,6 +752,20 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
DigammaOp::Adaptor transformed(operands);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeDigamma));
return success();
}
};
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@ -790,8 +917,13 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
context, patterns, 5);
// Other patterns.
patterns->insert<ConvertConstantLikeOp, ConvertErfOp, ConvertErfcOp,
// clang-format off
patterns->insert<ConvertConstantLikeOp,
ConvertDigammaOp,
ConvertErfOp,
ConvertErfcOp,
ConvertLgammaOp>(context);
// clang-format on
}
} // namespace chlo

View File

@ -53,8 +53,9 @@ namespace {
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) \
sep fn(ErfcOp) sep fn(LgammaOp) sep fn(SinhOp) sep fn(TanOp)
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp) \
sep fn(ErfOp) sep fn(ErfcOp) sep fn(LgammaOp) sep fn(SinhOp) \
sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -594,12 +594,26 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) ->
return %1 : tensor<2xindex>
}
// CHECK-LABEL: func @dynamic_reshape_rank_1_to_rank_1
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor<?xcomplex<f32>>,
%shape: tensor<?xindex>) -> tensor<?xf32> {
// CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.real"([[ARG0]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
// CHECK: return [[RES]]
%0 = "mhlo.real"(%arg0): (tensor<?xcomplex<f32>>) -> tensor<?xf32>
%1 = shape.shape_of %arg0 : tensor<?xcomplex<f32>> -> tensor<1xindex>
%2 = shape.num_elements %1 : tensor<1xindex> -> index
%3 = tensor.from_elements %2 : tensor<1xindex>
%4 = "mhlo.dynamic_reshape"(%0, %3)
: (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
return %4 : tensor<?xf32>
}
// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> {
// CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// CHECK: return [[RES]]
// CHECK: return [[ARG0]]
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
%1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
%2 = shape.num_elements %1 : tensor<?xindex> -> index

View File

@ -875,3 +875,233 @@ func @lgamma_f16(%arg : tensor<f16>) -> tensor<f16> {
%1 = chlo.lgamma %arg : tensor<f16> -> tensor<f16>
return %1 : tensor<f16>
}
// CHECK-LABEL: @digamma_f64
// CHECK-SAME: (%[[ARG:.*]]: tensor<f64>)
func @digamma_f64(%arg : tensor<f64>) -> tensor<f64> {
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
// CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"}
// CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0)
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]]
// CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_7:.*]] = mhlo.constant dense<0.99999999999980993>
// CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.5203681218851>
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
// CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
// CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
// CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
// CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
// CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.1392167224028>
// CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
// CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
// CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
// CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
// CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
// CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
// CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.32342877765313>
// CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
// CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
// CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
// CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
// CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
// CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
// CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.61502916214059>
// CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
// CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
// CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
// CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
// CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
// CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
// CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.507343278686905>
// CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
// CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
// CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
// CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
// CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
// CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
// CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.13857109526572012>
// CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
// CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
// CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
// CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
// CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
// CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
// CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
// CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.9843695780195716E-6>
// CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
// CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
// CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
// CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
// CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
// CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
// CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.5056327351493116E-7>
// CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
// CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
// CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
// CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
// CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
// CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
// CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
// CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
// CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
// CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.0149030205422647>
// CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
// CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
// CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
// CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
// CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
// CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
// CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
// CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
// CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
// CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]]
// CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
// CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
// CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]]
// CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.1415926535897931>
// CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
// CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
// CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
// CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
// CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
// CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
// CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
// CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"}
// CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0)
// CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"}
// CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
// CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FF8000000000000>
// CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
// CHECK: return %[[RES]]
%1 = chlo.digamma %arg : tensor<f64> -> tensor<f64>
return %1 : tensor<f64>
}
// CHECK-LABEL: @digamma_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @digamma_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
// CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"}
// CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0)
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]]
// CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_7:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.520386>
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
// CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
// CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
// CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
// CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
// CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.13916>
// CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
// CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
// CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
// CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
// CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
// CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
// CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.323425>
// CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
// CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
// CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
// CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
// CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
// CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
// CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.615036>
// CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
// CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
// CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
// CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
// CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
// CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
// CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.5073433>
// CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
// CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
// CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
// CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
// CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
// CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
// CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.138571098>
// CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
// CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
// CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
// CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
// CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
// CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
// CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
// CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.98436917E-6>
// CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
// CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
// CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
// CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
// CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
// CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
// CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.50563267E-7>
// CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
// CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
// CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
// CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
// CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
// CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
// CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
// CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
// CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
// CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.01490307>
// CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
// CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
// CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
// CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
// CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
// CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
// CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
// CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
// CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
// CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]]
// CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
// CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
// CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]]
// CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.14159274>
// CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
// CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
// CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
// CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
// CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
// CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
// CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
// CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"}
// CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0)
// CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"}
// CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
// CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FC00000>
// CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
// CHECK: return %[[RES]]
%1 = chlo.digamma %arg : tensor<f32> -> tensor<f32>
return %1 : tensor<f32>
}
// CHECK-LABEL: @digamma_f16
// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
// CHECK: return %[[RES]]
%1 = chlo.digamma %arg : tensor<f16> -> tensor<f16>
return %1 : tensor<f16>
}

View File

@ -163,7 +163,7 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>

View File

@ -851,12 +851,19 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
// CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
// CHECK: scf.yield %[[ACCUM]]
// CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1
// CHECK-SAME: iter_args(
// CHECK-SAME: %[[ITER0:.*]] = %c1
// CHECK-SAME: %[[ITER1:.*]] = %[[ARG0]]
// CHECK-SAME: %[[ITER2:.*]] = %[[ARG1]]
// CHECK-SAME: ) -> (i32, i32, i32) {
// CHECK: %[[AND:[a-zA-Z0-9_]*]] = and %[[ITER2]], %c1
// CHECK: %[[COND:[a-zA-Z0-9_]*]] = cmpi eq, %[[AND]], %c1
// CHECK: %[[MUL:[a-zA-Z0-9_]*]] = muli %[[ITER0]], %[[ITER1]]
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = select %[[COND]], %[[MUL]], %[[ITER0]]
// CHECK: %[[BASE:[a-zA-Z0-9_]*]] = muli %[[ITER1]], %[[ITER1]]
// CHECK: %[[EXP:[a-zA-Z0-9_]*]] = shift_right_unsigned %[[ITER2]], %c1
// CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]]
// CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2
// CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0
// CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0
@ -865,7 +872,7 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
// CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0
// CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1
// CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]
// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>

View File

@ -1526,5 +1526,57 @@ void PopulateLoweringTFPatterns(MLIRContext *context,
populateWithGenerated(context, *patterns);
}
void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
ConvertFakeQuantWithMinMaxVarsOp,
LowerAddNOp,
LowerBatchToSpaceND,
LowerDynamicStitchOp<DynamicStitchOp>,
LowerDynamicStitchOp<ParallelDynamicStitchOp>,
LowerInvertPermutationOp,
LowerPackOp,
LowerResizeNearestNeighbor,
LowerSpaceToBatchNDOp,
LowerSparseMatMulOp,
Lower_UnaryOpsComposition>(context);
// clang-format on
// Populate the relevant generated patterns.
// clang-format off
patterns->insert<
LowerBiasAddGradOp,
LowerDivNoNanOp,
LowerEmptyOp,
LowerExpm1Op,
LowerFakeQuantWithMinMaxArgs,
LowerFillOp,
LowerIsInfOp,
LowerIsNanOp,
LowerL2LossOp,
LowerMulNoNanOp,
LowerOnesLikeOp,
LowerPadOp,
LowerReciprocal,
LowerRintOp,
LowerRoundOpOnFloatTensor,
LowerRoundOpOnIntTensor,
LowerRsqrtGradOp,
LowerScatterNdOp,
LowerSizeOp,
LowerSoftmaxCrossEntropyWithLogitsOp,
LowerSparseSoftmaxCrossEntropyWithLogitsOp,
LowerSquareOp,
LowerSquaredDifferenceOpOnRealTensors,
LowerSquaredDifferenceOpOneComplexTensors,
LowerTanhGradOp,
LowerXdivyOp,
LowerXlog1pyOp,
LowerXlogyOp,
LowerZerosLikeOp>(context);
// clang-format on
}
} // namespace TF
} // namespace mlir

View File

@ -27,6 +27,14 @@ namespace TF {
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Populates TensorFlow lowering patterns to lower some of the TensorFlow
// operations that can be represented by means of other TensorFlow operations.
// This pattern collection preserves those TensorFlow operations that will later
// be lowered to equivalent operations in CHLO or MHLO. This allows for
// HLO-specific lowerings.
void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace TF
} // namespace mlir

View File

@ -134,16 +134,18 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
// Difference op patterns.
//===----------------------------------------------------------------------===//
def ComplexTensor : TensorOf<[AnyComplex]>;
def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>;
def ComplexTensor : TensorOf<[AnyComplex]>;
def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>;
def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
def LowerSquareOp : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
def : Pat<(TF_SquaredDifferenceOp RealTensor: $lhs, RealTensor:$rhs),
(TF_SquareOp (TF_SubOp $lhs, $rhs))>;
def LowerSquaredDifferenceOpOnRealTensors : Pat<
(TF_SquaredDifferenceOp RealTensor: $lhs, RealTensor:$rhs),
(TF_SquareOp (TF_SubOp $lhs, $rhs))>;
def : Pat<(TF_SquaredDifferenceOp ComplexTensor: $lhs, ComplexTensor:$rhs),
(TF_MulOp (TF_SubOp:$diff $lhs, $rhs), (TF_ConjOp $diff))>;
def LowerSquaredDifferenceOpOneComplexTensors : Pat<
(TF_SquaredDifferenceOp ComplexTensor: $lhs, ComplexTensor:$rhs),
(TF_MulOp (TF_SubOp:$diff $lhs, $rhs), (TF_ConjOp $diff))>;
//===----------------------------------------------------------------------===//
// DivNoNan and MulNonNan op patterns.
@ -156,9 +158,9 @@ class BinaryNoNanPat<Op FromOp, Op ToOp>
/*incompatible_shape_error*/ConstBoolAttrTrue),
$zero, (ToOp $l, $r))>;
foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp],
[TF_MulNoNanOp, TF_MulOp]] in
def : BinaryNoNanPat<fromToBinPair[0], fromToBinPair[1]>;
def LowerDivNoNanOp : BinaryNoNanPat<TF_DivNoNanOp, TF_DivOp>;
def LowerMulNoNanOp : BinaryNoNanPat<TF_MulNoNanOp, TF_MulOp>;
//===----------------------------------------------------------------------===//
// Expm1 op patterns.
@ -226,9 +228,13 @@ def LowerL2LossOp :
// Pad op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
(TF_PadV2Op $input, $paddings,
(TF_ConstOp (GetScalarOfType<0> $input)))>;
def LowerPadOp : Pat<
(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
(TF_PadV2Op $input, $paddings,
(TF_ConstOp
(GetScalarOfType<0> $input)
)
)>;
//===----------------------------------------------------------------------===//
// Reciprocal op patterns.
@ -243,45 +249,76 @@ def LowerReciprocal : Pat<(TF_ReciprocalOp $x),
// Rint is specified as RoundHalfToEven, which happens to be the same behavior
// as TF_RoundOp, so lower to TF_RoundOp.
def : Pat<(TF_RintOp:$res TF_FloatTensor:$input), (TF_RoundOp $input)>;
def LowerRintOp : Pat<(TF_RintOp:$res TF_FloatTensor:$input), (TF_RoundOp $input)>;
// Rounds on integers should just be bypassed.
def : Pat<(TF_RoundOp:$res TF_IntTensor:$input), (TF_IdentityOp $input)>;
def LowerRoundOpOnIntTensor : Pat<
(TF_RoundOp:$res TF_IntTensor:$input),
(TF_IdentityOp $input)>;
// Implements TF Round on floats using basic operations. TF Round is specified
// as RoundHalfToEven to be compatible with Numpy.
def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
(TF_SelectOp
(TF_LessOp
(TF_SubOp $input, (TF_FloorOp:$floor $input)),
(TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
$floor,
(TF_AddV2Op
(TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;
def LowerRoundOpOnFloatTensor : Pat<
(TF_RoundOp:$res TF_FloatTensor:$input),
(TF_SelectOp
(TF_LessOp
(TF_SubOp
$input,
(TF_FloorOp:$floor $input)
),
(TF_ConstOp
(GetScalarOfFloatType<"0.5"> $input)
)
),
$floor,
(TF_AddV2Op
(TF_ConstOp
(GetScalarOfType<1> $input)
),
$floor
)
)>;
//===----------------------------------------------------------------------===//
// Rsqrt op patterns.
//===----------------------------------------------------------------------===//
// RsqrtGrad(lhs, rhs) = (lhs * lhs * lhs) * (rhs / -2)
def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
(TF_MulOp (TF_MulOp (TF_MulOp $lhs, $lhs), $lhs),
(TF_DivOp $rhs,
(TF_ConstOp (GetScalarOfType<-2> $rhs))))>;
def LowerRsqrtGradOp : Pat<
(TF_RsqrtGradOp $lhs, $rhs),
(TF_MulOp
(TF_MulOp
(TF_MulOp $lhs, $lhs),
$lhs
),
(TF_DivOp
$rhs,
(TF_ConstOp
(GetScalarOfType<-2> $rhs)
)
)
)>;
//===----------------------------------------------------------------------===//
// Size op patterns.
//===----------------------------------------------------------------------===//
// Size(x) = Prod(Shape(x), reduction_indices=0, keep_dims=false)
def : Pat<(TF_SizeOp:$res $arg),
(TF_ProdOp
(CreateTFShapeOp $res, $arg, (IsI32 $res)),
/*reduction_indices=*/(TF_ConstOp (GetScalarOfType<0> $res)),
/*keep_dims=*/ConstBoolAttrFalse)>;
def LowerSizeOp : Pat<
(TF_SizeOp:$res $arg),
(TF_ProdOp
(CreateTFShapeOp
$res,
$arg,
(IsI32 $res)
),
/*reduction_indices=*/
(TF_ConstOp
(GetScalarOfType<0> $res)
),
/*keep_dims=*/
ConstBoolAttrFalse
)>;
//===----------------------------------------------------------------------===//
// TanhGrad op patterns.
@ -331,14 +368,28 @@ def LowerScatterNdOp :
// Xdivy, Xlog1p and Xlogy op patterns.
//===----------------------------------------------------------------------===//
class BinaryXopyPat<dag From, dag To>
: Pat<From,
(TF_SelectV2Op (TF_EqualOp $x,
(TF_ConstOp:$zero (GetScalarOfType<0> $x)),
/*incompatible_shape_error*/ConstBoolAttrTrue),
$zero, To)>;
class BinaryXopyPat<dag From, dag To> : Pat<
From,
(TF_SelectV2Op
(TF_EqualOp
$x,
(TF_ConstOp:$zero
(GetScalarOfType<0> $x)
),
/*incompatible_shape_error*/ConstBoolAttrTrue
),
$zero,
To
)>;
foreach fromToPair = [[(TF_XdivyOp $x, $y), (TF_DivOp $x, $y)],
[(TF_Xlog1pyOp $x, $y), (TF_MulOp $x, (TF_Log1pOp $y))],
[(TF_XlogyOp $x, $y), (TF_MulOp $x, (TF_LogOp $y))]] in
def : BinaryXopyPat<fromToPair[0], fromToPair[1]>;
def LowerXdivyOp : BinaryXopyPat<
(TF_XdivyOp $x, $y),
(TF_DivOp $x, $y)>;
def LowerXlog1pyOp : BinaryXopyPat<
(TF_Xlog1pyOp $x, $y),
(TF_MulOp $x, (TF_Log1pOp $y))>;
def LowerXlogyOp : BinaryXopyPat<
(TF_XlogyOp $x, $y),
(TF_MulOp $x, (TF_LogOp $y))>;

View File

@ -105,9 +105,29 @@ struct CollapseParallelLoopsTo1D
};
} // end anonymous namespace
Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> unroll_factors,
bool embed_memref_prints) {
struct TilingParams {
llvm::SmallVector<int64_t, 4> outer_tile;
llvm::SmallVector<int64_t, 4> inner_tile;
};
// We have to anticipate later unrolling in tiling to make sure that we get
// the requested tiling after unrolling. Compute the new tiling here if
// needed.
TilingParams ComputeTilingParas(llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> unroll_factors) {
TilingParams params;
params.outer_tile.reserve(tile_sizes.size());
for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
params.outer_tile.push_back(std::get<0>(pair) * std::get<1>(pair));
params.inner_tile.push_back(std::get<1>(pair));
}
params.outer_tile.append(tile_sizes.drop_front(unroll_factors.size()).begin(),
tile_sizes.end());
return params;
}
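As a worked example (the unroll factor is hypothetical; the defaults in the header below are tile_sizes = {16, 64} and unroll_factors = {}): for tile_sizes = {16, 64} and unroll_factors = {4}, the leading dimension is scaled by its unroll factor and the remaining tile sizes are passed through, so outer_tile = {64, 64} and inner_tile = {4}. The outer tiling is applied first; a second tiling pass with inner_tile then produces the loops that can later be unrolled, which is how the requested tiling survives unrolling.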
Status LowerTFtoLoops(mlir::ModuleOp module,
const TilingParams& tiling_params) {
mlir::PassManager pm(module.getContext());
applyTensorflowAndCLOptions(pm);
@ -154,29 +174,31 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
// Collapse and tile parallel loops.
pm.addNestedPass<mlir::FuncOp>(std::make_unique<CollapseParallelLoopsTo1D>());
// We have to anticipate later unrolling in tiling to make sure that we get
// the requested tiling after unrolling. Compute the new tiling here if
// needed.
llvm::SmallVector<int64_t, 4> tiling_for_unrolling, inner_tile;
tiling_for_unrolling.reserve(tile_sizes.size());
for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
inner_tile.push_back(std::get<1>(pair));
}
tiling_for_unrolling.append(
tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
pm.addNestedPass<mlir::FuncOp>(
::mlir::createParallelLoopTilingPass(tiling_for_unrolling));
if (!unroll_factors.empty()) {
::mlir::createParallelLoopTilingPass(tiling_params.outer_tile));
if (!tiling_params.inner_tile.empty()) {
pm.addNestedPass<mlir::FuncOp>(
::mlir::createParallelLoopTilingPass(inner_tile));
::mlir::createParallelLoopTilingPass(tiling_params.inner_tile));
}
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
if (failed(pm.run(module))) {
return InternalError("Lowering TF to loops failed.");
}
return Status::OK();
}
// Greedily map the remaining loop to GPU hardware dimensions.
pm.addNestedPass<::mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
Status LowerLoopsToGPUorCPU(mlir::ModuleOp module,
const TilingParams& tiling_params,
bool embed_memref_prints, bool cpu_codegen) {
mlir::PassManager pm(module.getContext());
applyTensorflowAndCLOptions(pm);
if (!cpu_codegen) {
// Greedily map the remaining loop to GPU hardware dimensions.
pm.addNestedPass<::mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
}
// Now lower the shape computations, bufferize all remaining ops and insert
// deallocs.
@ -213,21 +235,25 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
// Apply the mapping and go to GPU. We cannot do this earlier due to missing
// interfaces on the GPU dialect.
// TODO(b/174830459): Move up once implemented.
pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
if (!cpu_codegen) {
pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
}
// Some basic cleanup.
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Make loops with min bounds into a conditional plus static bounds.
// Only do this if we unrolled in the first place.
if (!unroll_factors.empty()) {
if (!tiling_params.inner_tile.empty()) {
pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
}
// Approximate Tanh using standard operations.
pm.addNestedPass<::mlir::FuncOp>(
::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
// Take launches to launches with kernels.
pm.addPass(::mlir::createGpuKernelOutliningPass());
if (!cpu_codegen) {
pm.addPass(::mlir::createGpuKernelOutliningPass());
}
pm.addPass(::mlir::createLowerAffinePass());
// Constraints are removed as late as possible and before lowering to CFG.
@ -244,6 +270,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
if (failed(pm.run(module))) {
return InternalError("Lowering to GPU kernels failed.");
}
return Status::OK();
}
Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
auto gpu_modules = module.getOps<mlir::gpu::GPUModuleOp>();
auto num_modules = std::distance(gpu_modules.begin(), gpu_modules.end());
if (num_modules != 1) {
@ -252,10 +282,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
<< ". Currently we leak memory if there is more than one "
"module, see https://bugs.llvm.org/show_bug.cgi?id=48385";
}
return Status::OK();
}
Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
#if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
return InternalError(
"Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
@ -337,18 +363,23 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
llvm::ArrayRef<std::string> architectures,
llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> unroll_factors, bool embed_memref_prints,
bool generate_fatbin, bool print_ptx, bool enable_ftz) {
bool generate_fatbin, bool print_ptx, bool enable_ftz, bool cpu_codegen) {
auto& registry = context.getDialectRegistry();
mlir::RegisterAllTensorFlowDialects(registry);
registry.insert<mlir::chlo::HloClientDialect, mlir::mhlo::MhloDialect>();
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), tile_sizes, unroll_factors,
embed_memref_prints));
TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
architectures, generate_fatbin,
print_ptx, enable_ftz));
TilingParams tiling_params = ComputeTilingParas(tile_sizes, unroll_factors);
TF_RETURN_IF_ERROR(LowerTFtoLoops(module.get(), tiling_params));
TF_RETURN_IF_ERROR(LowerLoopsToGPUorCPU(module.get(), tiling_params,
embed_memref_prints, cpu_codegen));
if (!cpu_codegen) {
TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
architectures, generate_fatbin,
print_ptx, enable_ftz));
}
TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get()));
return module;
}

View File

@ -33,14 +33,14 @@ limitations under the License.
namespace tensorflow {
namespace kernel_gen {
// Converts TF code to LLVM/NVVM. Lowers the host side to LLVM Dialect.
// Converts TF code to LLVM with or without GPU support.
xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
mlir::MLIRContext& context, llvm::StringRef tf_code,
llvm::ArrayRef<std::string> architectures = {"sm_75"},
llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
llvm::ArrayRef<uint32_t> unroll_factors = {},
bool embed_memref_prints = false, bool generate_fatbin = true,
bool print_ptx = false, bool enable_ftz = false);
bool print_ptx = false, bool enable_ftz = false, bool cpu_codegen = false);
} // namespace kernel_gen
} // namespace tensorflow

View File

@ -106,7 +106,8 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
llvm::ArrayRef<std::string> architectures,
llvm::ArrayRef<uint32_t> tile_sizes,
llvm::ArrayRef<uint32_t> unroll_factors,
bool embed_memref_prints, bool print_ptx, bool enable_ftz) {
bool embed_memref_prints, bool print_ptx, bool enable_ftz,
bool cpu_codegen) {
// Read TF code.
std::string tf_code;
TF_RETURN_IF_ERROR(
@ -117,7 +118,8 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
mlir::OwningModuleRef module,
GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes,
unroll_factors, embed_memref_prints,
/*generate_fatbin=*/true, print_ptx, enable_ftz));
/*generate_fatbin=*/true, print_ptx, enable_ftz,
cpu_codegen));
// Get binary.
TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
@ -138,18 +140,21 @@ int main(int argc, char** argv) {
llvm::cl::opt<std::string> output_file(
"output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
llvm::cl::init("foo.bin"));
llvm::cl::opt<bool> cpu_codegen("cpu_codegen",
llvm::cl::desc("enable CPU code generation"),
llvm::cl::init(false));
llvm::cl::opt<bool> embed_memref_prints(
"embed_memref_prints",
llvm::cl::desc("embeds memref prints at the end of their lifetime"),
llvm::cl::desc("embed memref prints at the end of their lifetime"),
llvm::cl::init(false));
llvm::cl::opt<bool> print_ptx(
"print-ptx",
llvm::cl::desc("Print generated PTX code per target architecture."),
llvm::cl::desc("print generated PTX code per target architecture."),
llvm::cl::init(false));
llvm::cl::opt<bool> enable_ftz(
"enable_ftz",
llvm::cl::desc(
"Enable the denormal flush to zero mode when generating code."),
"enable the denormal flush to zero mode when generating code."),
llvm::cl::init(false));
llvm::cl::list<std::string> architectures(
"arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"),
@ -166,11 +171,11 @@ int main(int argc, char** argv) {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerPassManagerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
llvm::cl::ParseCommandLineOptions(argc, argv, "TF op kernel generator\n");
auto status = tensorflow::kernel_gen::Run(
input_file, output_file, architectures, tile_sizes, unroll_factors,
embed_memref_prints, print_ptx, enable_ftz);
embed_memref_prints, print_ptx, enable_ftz, cpu_codegen);
if (!status.ok()) {
LOG(ERROR) << status;
return 1;

View File

@ -49,7 +49,9 @@ class GpuKernelToNVVMPass
GPUModuleOp m = getOperation();
OwningRewritePatternList patterns;
LLVMTypeConverter converter(m.getContext());
mlir::LowerToLLVMOptions llvm_opts;
llvm_opts.indexBitwidth = 32;
LLVMTypeConverter converter(m.getContext(), llvm_opts);
populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToNVVMConversionPatterns(converter, patterns);
populateComplexToLLVMConversionPatterns(converter, patterns);

View File

@ -1075,7 +1075,7 @@ func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
// CHECK-LABEL: func @infeed_dequeue_tuple
func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = "", layout = [{{\[\[0], \[0]]}}, unit]} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
@ -5149,3 +5149,12 @@ func @replica_id() -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
%0 = "tf.Angle"(%arg0): (tensor<complex<f32>>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -70,7 +70,6 @@ namespace mhlo {
namespace {
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kLayoutAttr[] = "layout";
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@ -4609,9 +4608,6 @@ class ConvertInfeedDequeueTupleOp
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
/*infeed_config=*/rewriter.getStringAttr(""));
data_and_token->setAttr(kLayoutAttr,
GetLayout(data_and_token_type, rewriter));
if (op._XlaSharding().hasValue()) {
// _XlaSharding attribute in TF is a serialized string of the OpSharding
// proto, so convert to a text form here.
@ -6144,7 +6140,7 @@ LogicalResult legalizeTF(
PopulateLegalizeTfPatterns(context, &patterns);
// Add TF->TF lowering patterns.
TF::PopulateLoweringTFPatterns(context, &patterns);
TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
// Add TF->HLO legalization patterns via TF2XLA fallback.
if (tf2xla_fallback_device_type.hasValue()) {

View File

@ -595,6 +595,7 @@ foreach Mapping = [
[TF_ComplexAbsOp, HLO_AbsOp],
[TF_ConjOp, HLOClient_ConjOp],
[TF_CosOp, HLO_CosOp],
[TF_DigammaOp, HLOClient_DigammaOp],
[TF_ExpOp, HLO_ExpOp],
[TF_ErfOp, HLOClient_ErfOp],
[TF_ErfcOp, HLOClient_ErfcOp],
@ -620,6 +621,8 @@ foreach Mapping = [
(Mapping[1] $input)>;
}
def : Pat<(TF_AngleOp $x), (HLO_Atan2Op (HLO_ImagOp $x), (HLO_RealOp $x))>;
// TODO(bixia): Lower Cast with a Complex type source operand or with
// Truncate=True for floating point value conversions.
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),

View File

@ -184,7 +184,7 @@ class XlaBuilder {
// Similar to SetOpMetadata, but only set the metadata for the next op.
void SetOneShotOpMetadata(OpMetadata metadata) {
metadata_ = std::move(metadata);
one_shot_metadata_ = std::move(metadata);
}
// Clears the HloMetadata state.

View File

@ -335,10 +335,15 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
":absl_casters",
":jax_jit",
":py_client",
":types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"@pybind11",
],

View File

@ -855,7 +855,7 @@ class CompiledFunction {
py::object Call(py::args args, py::kwargs kwargs);
// This allows `inspect.signature(cpp_jitted_f)` from Python.
py::object __signature__() {
py::object PythonSignature() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(fun_);
}
@ -1212,7 +1212,8 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
jitlib, "CompiledFunction");
cfun.def("__call__", &CompiledFunction::Call);
cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
cfun.def_property_readonly("__signature__",
&CompiledFunction::PythonSignature);
jitlib.def("set_disable_jit", &SetDisableJit);
jitlib.def("get_disable_jit", &GetDisableJit);

View File

@ -21,11 +21,17 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace jax {
@ -90,6 +96,246 @@ pybind11::tuple CppMeshMappingToPy(
return result;
}
namespace {
struct PmapCacheEntry {
// To get a first version running, we make extensive use of Python here for
// handling the arguments and outputs.
// TODO(jblespiau): Move more to C++.
std::shared_ptr<xla::PyExecutable> executable;
// See _cpp_pmap in api.py.
py::object backend;
// A function that takes a list of arguments and returns a list of lists of
// buffers `[num_devices x num_args]`.
py::function handle_args;
// A function that takes the output of `ExecuteOnLocalDevices` and returns a
// list of ShardedDeviceArray objects.
py::function out_handler;
xla::PyTreeDef out_pytree_def;
// Ensures a single thread performs the compilation for a given executable.
//
// The first thread (holding the GIL) will create the CacheEntry associated
// with a signature; if the entry has already been inserted, other threads
// will wait for the notification.
absl::Notification compilation_complete;
absl::optional<xla::Status> compilation_error = absl::nullopt;
bool fall_back_to_python = false;
};
} // namespace
// A `PmapFunction` is associated with a `jax.pmap(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class PmapFunction {
public:
PmapFunction(py::function fun, py::function cache_miss,
py::function get_jax_enable_x64, std::vector<int> static_argnums)
: fun_(std::move(fun)),
cache_miss_(std::move(cache_miss)),
static_argnums_(std::move(static_argnums)),
get_jax_enable_x64_(get_jax_enable_x64) {
std::sort(static_argnums_.begin(), static_argnums_.end());
}
~PmapFunction() {
for (const auto& entry : executables_) {
entry.first.DecRef();
}
}
// This function will:
// (a) flatten the inputs using pytree
// (b) get buffer objects from the arguments
// (c) call the executable
// (d) construct `ShardedDeviceArray` objects from the outputs
// (e) reconstruct the `PyTree`.
py::object Call(py::args args, py::kwargs kwargs);
py::object PythonSignature() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(fun_);
}
int cache_size() const { return executables_.size(); }
private:
// Returns nullptr if not present in the cache.
PmapCacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
// Should never return nullptr.
PmapCacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data);
bool always_fallback_to_python_ = false;
const py::function fun_; // The Python function to pmap.
// See JAX _cpp_pmap in api.py for documentation.
const py::function cache_miss_;
// We need to know the static arguments to remove them from the arguments
// passed to the underlying PyExecutable. In sorted order.
std::vector<int> static_argnums_;
// We need a `unique_ptr` here to ensure value pointer stability.
absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
executables_;
const py::function get_jax_enable_x64_;
absl::optional<bool> jax_enable_x64_ = absl::nullopt;
// A vector of size `num_outputs`, specifying the sharding of each output
std::vector<ShardingSpec> sharding_specs_;
};
PmapCacheEntry* PmapFunction::GetCacheEntryIfPresent(
const CallSignature& signature) {
auto found_iterator = executables_.find(signature);
if (found_iterator != executables_.end()) { // Cache hit!
if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
py::gil_scoped_release gil_release;
found_iterator->second->compilation_complete.WaitForNotification();
}
if (found_iterator->second->compilation_error) {
throw std::invalid_argument(
found_iterator->second->compilation_error.value().error_message());
}
return found_iterator->second.get();
}
return nullptr;
}
PmapCacheEntry* PmapFunction::AddCacheEntry(const py::args& args,
const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data) {
// We need to insert the element.
auto result =
executables_.emplace(signature, std::make_unique<PmapCacheEntry>());
auto it = result.first;
PmapCacheEntry* cache_entry = it->second.get();
// CallSignatures in the cache own their keyword argument reference.
result.first->first.IncRef();
py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
CHECK_EQ(tuple.size(), 2);
if (tuple[1].is_none()) {
cache_entry->fall_back_to_python = true;
cache_entry->compilation_complete.Notify();
return cache_entry;
}
py::dict pmap_data = py::cast<py::dict>(tuple[1]);
if (py::cast<int>(pmap_data["version"]) != 1) {
throw std::runtime_error(absl::StrCat(
"The versions of jaxlib and Jax are incompatible (pmap cpp version 1 "
"expected, but got ",
py::cast<int>(pmap_data["version"]),
"Upgrade jaxlib and jax. Provided data was:",
py::cast<std::string>(py::str(py::repr(pmap_data)))));
}
// { "version": 1,
// "xla_executable": xla_executable,
// "in_handler": in_handler,
// "out_handler": out_handler,
// "out_pytree_def": out_pytree_def }
auto executable =
py::cast<std::shared_ptr<xla::PyExecutable>>(pmap_data["xla_executable"]);
cache_entry->executable = std::move(executable);
cache_entry->handle_args = py::cast<py::function>(pmap_data["in_handler"]);
cache_entry->out_handler = py::cast<py::function>(pmap_data["out_handler"]);
auto out_tree = py::cast<xla::PyTreeDef>(pmap_data["out_pytree_def"]);
cache_entry->out_pytree_def = std::move(out_tree);
cache_entry->compilation_complete.Notify();
return cache_entry;
}
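Distilled from GetCacheEntryIfPresent and AddCacheEntry above, the "first caller compiles, later callers wait" protocol looks roughly like the self-contained sketch below. Everything here is a stand-in: an int key replaces CallSignature, squaring replaces the actual compilation via cache_miss_, and the GIL (which serializes map mutation in the real code) is assumed rather than modeled.

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/notification.h"

namespace {

struct Entry {
  absl::Notification compilation_complete;
  int compiled = 0;  // Stand-in for the compiled executable.
};

using Cache = absl::flat_hash_map<int, std::unique_ptr<Entry>>;

int GetOrCompile(Cache& cache, int signature) {
  auto emplaced = cache.emplace(signature, std::make_unique<Entry>());
  Entry* entry = emplaced.first->second.get();
  if (emplaced.second) {
    // This caller inserted the entry: "compile" once, then wake any waiters.
    entry->compiled = signature * signature;
    entry->compilation_complete.Notify();
  } else if (!entry->compilation_complete.HasBeenNotified()) {
    // Another caller is compiling the same signature; block until it is done
    // (the real code releases the GIL around this wait).
    entry->compilation_complete.WaitForNotification();
  }
  return entry->compiled;
}

}  // namespace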
py::object PmapFunction::Call(py::args args, py::kwargs kwargs) {
if (always_fallback_to_python_) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// Delayed values are retrieved on the first call to `Call`.
if (jax_enable_x64_ == absl::nullopt) {
jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
}
ParsedArgumentsAsBuffers arguments;
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// Get dynamic argument signatures.
for (py::handle arg : arguments.flat_dynamic_args) {
auto signature_or_error = ArgSignatureOfValue(arg, jax_enable_x64_.value());
if (!signature_or_error.ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
arguments.signature.dynamic_args_signatures.push_back(
std::move(signature_or_error).ValueOrDie());
}
// Retrieve/Maybe add the executable to the cache.
PmapCacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
out_and_fastpath_data);
}
CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}
// As we have already computed the results, we can return them.
// It's even *required*, e.g. if there are donated arguments, because
// otherwise the buffers that have already been donated would be invalid.
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}
CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// TODO(jblespiau): Use C++ only for this.
py::list arg_list;
for (auto& v : arguments.flat_dynamic_args) {
arg_list.append(v);
}
py::object handled_args = cache_entry->handle_args(arg_list);
py::list list_of_list_of_buffers = py::cast<py::list>(handled_args);
arguments.keep_alive_objects.push_back(
py::cast<py::object>(list_of_list_of_buffers));
// Should be `[num_devices x num_args]`.
std::vector<std::vector<xla::PyBuffer*>> arg_buffers;
arg_buffers.reserve(list_of_list_of_buffers.size());
for (int i = 0; i < list_of_list_of_buffers.size(); ++i) {
std::vector<xla::PyBuffer*> buffers;
buffers.reserve(py::cast<py::list>(list_of_list_of_buffers[i]).size());
for (auto& buf : list_of_list_of_buffers[i]) {
buffers.push_back(py::cast<xla::PyBuffer*>(buf));
}
arg_buffers.push_back(std::move(buffers));
}
std::vector<std::vector<std::unique_ptr<xla::PyBuffer>>> outputs =
ValueOrThrow(cache_entry->executable->ExecuteOnLocalDevices(arg_buffers));
// TODO(jblespiau): Move this to C++.
py::list outputs_as_python_objects;
for (int i = 0; i < outputs.size(); ++i) {
outputs_as_python_objects.append(py::cast(std::move(outputs[i])));
}
py::list flat_sharded_device_arrays =
cache_entry->out_handler(outputs_as_python_objects);
return cache_entry->out_pytree_def.Unflatten(flat_sharded_device_arrays);
}
void BuildPmapSubmodule(pybind11::module& m) {
py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");
@ -167,6 +413,21 @@ void BuildPmapSubmodule(pybind11::module& m) {
&ShardedDeviceArray::GetShardingSpec)
.def_property_readonly("device_buffers",
&ShardedDeviceArray::GetDeviceBuffers);
py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib,
"PmapFunction");
cfun.def("__call__", &PmapFunction::Call);
cfun.def_property_readonly("__signature__", &PmapFunction::PythonSignature);
pmap_lib.def(
"pmap",
[](py::function fun, py::function cache_miss,
py::function get_jax_enable_x64,
std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> {
return std::make_unique<PmapFunction>(
std::move(fun), std::move(cache_miss),
std::move(get_jax_enable_x64), std::move(static_argnums));
});
}
} // namespace jax

View File

@ -70,12 +70,12 @@ struct Chunked {
};
// `Unstacked` means that the dimension is split into chunks of size 1, and
// doesn't appear inside the map. `size` is alwyays the dimension size.
// doesn't appear inside the map. `size` is always the dimension size.
// For example, a Tensor t of shape [N] will be sharded into N tensors of shape
// [], when using `Unstacked(N)`.
struct Unstacked {
public:
explicit Unstacked(int size_) : size(size_) {}
explicit Unstacked(int sz) : size(sz) {}
const int size;
bool operator==(const Unstacked& other) const { return size == other.size; }
@ -121,7 +121,6 @@ pybind11::tuple CppMeshMappingToPy(std::vector<MeshDimAssignment> mesh_mapping);
// Describes how each axis is sharded (if it is), and how it's mapped to the
// devices mesh.
// See `AvalDimSharding` and `MeshDimAssignment`.
class ShardingSpec {
public:
ShardingSpec(std::vector<AvalDimSharding> sharding,

View File

@ -1498,11 +1498,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) {
void HloInstruction::SetupDerivedInstruction(
HloInstruction* derived_instruction) const {
if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
shape_, derived_instruction->shape())) {
// Only copy sharding if the shape of the two instruction is compatible
// because copying it between differently shaped instructions can produce
// invalid shardings.
if (sharding_ != nullptr &&
ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
// Only copy sharding if the tuple tree shape of the two instructions is
// compatible because copying it between differently shaped instructions
// can produce invalid shardings.
derived_instruction->set_sharding(*sharding_);
} else {
derived_instruction->clear_sharding();

View File

@ -749,6 +749,64 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
}
TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Compatible with original shape as tuple tree structure and leaf ranks are
// identical
auto clone_shape = ShapeUtil::MakeShape(F32, {3, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_EQ(tuple_clone->sharding(), sharding);
}
TEST_F(HloInstructionTest,
DoNotPreserveShardingThroughTupleTreeIncompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Incompatible with original shape as tuple tree structure is different
auto clone_shape = ShapeUtil::MakeShape(F32, {2, 2});
clone_shape =
ShapeUtil::MakeTupleShape({clone_shape, clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_FALSE(tuple_clone->has_sharding());
}
TEST_F(HloInstructionTest,
DoNotPreserveShardingThroughLeafRankIncompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Incompatible with original shape as leaf ranks are different
auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_FALSE(tuple_clone->has_sharding());
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

View File

@ -141,9 +141,16 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
}
}
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
if (!ignore_dimensions_) {
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
} else {
if (!ShapeUtil::SameRank(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs rank != rhs rank";
return false;
}
}
if (!ignore_layout_) {

View File

@ -220,6 +220,10 @@ class Shape {
ignore_dynamic_dimension_ = true;
return *this;
}
Equal& IgnoreDimensions() {
ignore_dimensions_ = true;
return *this;
}
private:
bool ignore_layout_ = false;
@ -229,6 +233,7 @@ class Shape {
bool ignore_element_type_ = false;
bool ignore_fp_precision_ = false;
bool ignore_dynamic_dimension_ = false;
bool ignore_dimensions_ = false;
};
// Test that all fields of the shape are the same, equivalent to Equal().

View File

@ -644,6 +644,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}
/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
CHECK(lhs.IsArray());
CHECK(rhs.IsArray());
return lhs.rank() == rhs.rank();
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}
@ -656,6 +662,15 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
.IgnoreLayout()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal()
.IgnoreElementType()
.IgnoreLayout()
.IgnoreDimensions()
.IgnoreDynamicDimension()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal()

View File

@ -247,6 +247,11 @@ class ShapeUtil {
// Precondition: IsArray(lhs) && IsArray(rhs)
static bool SameDimensions(const Shape& lhs, const Shape& rhs);
// Returns whether the LHS and RHS shapes have the same rank; note: does
// not check element type.
// Precondition: IsArray(lhs) && IsArray(rhs)
static bool SameRank(const Shape& lhs, const Shape& rhs);
// Returns whether the lhs and rhs shapes have the same element type.
static bool SameElementType(const Shape& lhs, const Shape& rhs) {
return lhs.element_type() == rhs.element_type();
@ -308,6 +313,11 @@ class ShapeUtil {
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
// Returns true if the tuple tree shapes and leaf ranks are identical.
// Leaf dimensions, element type, and layout are ignored. Tuple elements are
// compared recursively for compatibility.
static bool CompatibleKind(const Shape& lhs, const Shape& rhs);
// As Compatible, but allow one of lhs and rhs to be BF16 while the other
// being F32. Tuple elements are compared recursively for compatibility.
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
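As an illustration (not part of the change), the new CompatibleKind accepts shapes that Compatible rejects as long as the tuple tree and leaf ranks line up; the shapes below are arbitrary examples and the helper function is hypothetical:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void CompatibleKindExample() {
  // Same tuple structure, same leaf ranks (2 and 1), but different element
  // types and dimensions.
  Shape a = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 2}), ShapeUtil::MakeShape(F32, {4})});
  Shape b = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {7, 9}), ShapeUtil::MakeShape(F64, {1})});
  CHECK(ShapeUtil::CompatibleKind(a, b));  // Tree and leaf ranks match.
  CHECK(!ShapeUtil::Compatible(a, b));     // Dimensions and types differ.
}

}  // namespace xla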

View File

@ -27,9 +27,12 @@ REGISTER_COMPLEX(CPU, float, complex64);
REGISTER_COMPLEX(CPU, double, complex128);
#if GOOGLE_CUDA
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER_COMPLEX(GPU, float, complex64);
REGISTER_COMPLEX(GPU, double, complex128);
#endif
#endif
#undef REGISTER_COMPLEX
} // namespace tensorflow

View File

@ -16,10 +16,16 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER3(UnaryOp, CPU, "Digamma", functor::digamma, float, Eigen::half,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "Digamma", functor::digamma, float, Eigen::half,
double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif
#endif
} // namespace tensorflow

View File

@ -20,8 +20,11 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if GOOGLE_CUDA
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
DEFINE_UNARY2(get_angle, complex64, complex128);
#endif
#endif
} // namespace functor
} // namespace tensorflow

View File

@ -19,8 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isfinite, Eigen::half, float, double);
#endif
} // namespace functor

View File

@ -19,8 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isinf, Eigen::half, float, double);
#endif
} // namespace functor

View File

@ -19,8 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isnan, Eigen::half, float, double);
#endif
} // namespace functor

View File

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsFinite", functor::isfinite, float, Eigen::half,
bfloat16, double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsFinite", functor::isfinite, float, Eigen::half,
double);
#endif

View File

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, bfloat16,
double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double);
#endif
#endif

View File

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double,
bfloat16);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsNan", functor::isnan, float, Eigen::half, double);
#endif
#endif

View File

@ -52,6 +52,9 @@ filegroup(
"gpu_op_floor.cc",
"gpu_op_imag.cc",
"gpu_op_invert.cc",
"gpu_op_is_finite.cc",
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_log.cc",
"gpu_op_log1p.cc",
"gpu_op_logical_not.cc",
@ -71,13 +74,12 @@ filegroup(
srcs = [
"gpu_op_acos.cc",
"gpu_op_acosh.cc",
"gpu_op_angle.cc",
"gpu_op_asin.cc",
"gpu_op_asinh.cc",
"gpu_op_digamma.cc",
"gpu_op_exp.cc",
"gpu_op_expm1.cc",
"gpu_op_is_finite.cc",
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_lgamma.cc",
"gpu_op_sign.cc",
"gpu_op_sin.cc",
@ -97,9 +99,9 @@ filegroup(
)
cc_library(
name = "gpu_ops_base",
srcs = ["gpu_ops_base.cc"],
hdrs = ["gpu_ops_base.h"],
name = "base_op",
srcs = ["base_op.cc"],
hdrs = ["base_op.h"],
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_framework_c_interface",
@ -113,6 +115,13 @@ cc_library(
],
)
cc_library(
name = "base_gpu_op",
hdrs = ["base_gpu_op.h"],
compatible_with = get_compatible_with_cloud(),
deps = [":base_op"],
)
tf_kernel_library(
name = "cwise_unary_op",
srcs = [":unary_kernel_srcs"],
@ -129,6 +138,7 @@ tf_kernel_library(
":square_kernels",
":acos_kernels",
":acosh_kernels",
":angle_kernels",
":asin_kernels",
":asinh_kernels",
":atan_kernels",
@ -139,6 +149,7 @@ tf_kernel_library(
":conj_kernels",
":cos_kernels",
":cosh_kernels",
":digamma_kernels",
":erf_kernels",
":erfc_kernels",
":exp_kernels",
@ -162,7 +173,7 @@ tf_kernel_library(
":sqrt_kernels",
":tan_kernels",
":tanh_kernels",
":gpu_ops_base",
":base_gpu_op",
"//third_party/eigen3",
],
)
@ -200,13 +211,13 @@ tf_kernel_library(
deps = [
":add_v2_kernels",
":atan2_kernels",
":base_gpu_op",
":bitwise_and_kernels",
":bitwise_or_kernels",
":bitwise_xor_kernels",
":div_kernels",
":equal_kernels",
":floor_div_kernels",
":gpu_ops_base",
":greater_equal_kernels",
":greater_kernels",
":left_shift_kernels",
@ -354,7 +365,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
@ -364,7 +374,20 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
name = "angle",
output_types = [
"f32",
"f64",
],
tile_size = "256",
types = [
"c64",
"c128",
],
unroll_factors = "2",
)
gen_kernel_library(
@ -374,7 +397,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
@ -384,7 +406,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)
gen_kernel_library(
@ -514,7 +535,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)
[
@ -769,6 +789,7 @@ gen_kernel_library(
)
for name in [
"ceil",
"digamma",
"exp",
"expm1",
"floor",

View File

@ -0,0 +1,77 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_
#include "tensorflow/core/kernels/mlir_generated/base_op.h"
namespace tensorflow {
#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
data_type)
#define GENERATE_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)
#define GENERATE_UNARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
GENERATE_UNARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type)
#define REGISTER_ALIASED_GPU_KERNEL(tf_op, mlir_op, mlir_type, \
mlir_output_type, data_type) \
REGISTER_ALIASED_KERNEL(tf_op, mlir_op, GPU, mlir_type, mlir_output_type, \
data_type)
#define REGISTER_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
REGISTER_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type)
#define REGISTER_COMPLEX_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, \
data_type, input_data_type) \
REGISTER_COMPLEX_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type, \
input_data_type)
#define REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, \
mlir_output_type) \
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, GPU, mlir_type, mlir_output_type)
#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(tf_op, mlir_type, \
tf_data_type, data_type) \
GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
data_type)
#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2( \
tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
input_data_type) \
GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, GPU, mlir_type, \
mlir_output_type, tf_data_type, \
result_data_type, input_data_type)
#define GENERATE_BINARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)
#define GENERATE_BINARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_op.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@ -97,9 +97,9 @@ Tensor ConvertDescriptorToTensor(
template <DataType TfDataType, typename OutputDataType, typename Kernel,
typename InputDataType = OutputDataType>
class MlirUnrankedOp : public OpKernel {
class MlirOp : public OpKernel {
public:
explicit MlirUnrankedOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit MlirOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
llvm::SmallVector<::UnrankedMemRefType<InputDataType>, 2> input_descs;
@ -142,114 +142,118 @@ class MlirUnrankedOp : public OpKernel {
}
};
#define MLIR_FUNCTION(tf_op, mlir_type, mlir_output_type) \
_mlir_ciface_##tf_op##_##mlir_type##_##mlir_output_type
#define MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type) \
_mlir_ciface_##tf_op##_##platform##_##mlir_type##_##mlir_output_type
#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, mlir_type, mlir_output_type, \
data_type) \
#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, platform, mlir_type, \
mlir_output_type, data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
MlirUnranked##mlir_op##mlir_type##mlir_output_type##Op);
Name(#tf_op).Device(DEVICE_##platform).TypeConstraint<data_type>("T"), \
Mlir##mlir_op##platform##mlir_type##mlir_output_type##Op);
#define REGISTER_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
REGISTER_ALIASED_KERNEL(tf_op, tf_op, mlir_type, mlir_output_type, data_type)
#define REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
data_type) \
REGISTER_ALIASED_KERNEL(tf_op, tf_op, platform, mlir_type, mlir_output_type, \
data_type)
#define REGISTER_COMPLEX_KERNEL(tf_op, mlir_type, mlir_output_type, data_type, \
input_data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op) \
.Device(DEVICE_GPU) \
.TypeConstraint<input_data_type>("T") \
.TypeConstraint<data_type>("Tout"), \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op);
#define REGISTER_COMPLEX_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
data_type, input_data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op) \
.Device(DEVICE_##platform) \
.TypeConstraint<input_data_type>("T") \
.TypeConstraint<data_type>("Tout"), \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);
#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, mlir_output_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU), \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op);
#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, platform, mlir_type, \
mlir_output_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_##platform), \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);
// OpKernel with Compute function that converts input tensors to unranked
// memref descriptors and calls the mlir-generated unranked kernel. The outputs
// are converted back to tensors using MlirTensorBuffer to take ownership of
// pre-allocated memory.
#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_type, data_type)
#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, platform, mlir_type, \
tf_data_type, data_type) \
GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)
#define GENERATE_AND_REGISTER_BINARY_KERNEL2( \
tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_output_type, tf_data_type, \
result_data_type, input_data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_output_type, input_data_type)
#define GENERATE_AND_REGISTER_BINARY_KERNEL2( \
tf_op, platform, mlir_type, mlir_output_type, tf_data_type, \
result_data_type, input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, input_data_type)
#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_type, tf_data_type, \
#define GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
data_type) \
GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
data_type, data_type)
#define GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type, \
mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg1, \
const ::UnrankedMemRefType<input_data_type>* arg2); \
\
namespace { \
class MlirUnranked##tf_op##mlir_type##mlir_output_type##Op \
: public MlirUnrankedOp< \
tf_data_type, result_data_type, \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirUnrankedOp::MlirUnrankedOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>(MLIR_FUNCTION( \
tf_op, mlir_type, mlir_output_type)(ctx, &args[0], &args[1])); \
} \
}; \
#define GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg1, \
const ::UnrankedMemRefType<input_data_type>* arg2); \
\
namespace { \
class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
: public MlirOp<tf_data_type, result_data_type, \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirOp::MlirOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>( \
MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type)( \
ctx, &args[0], &args[1])); \
} \
}; \
}
#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_type, data_type)
#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, mlir_type, \
tf_data_type, data_type) \
GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)
#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_UNARY_KERNEL2(tf_op, mlir_type, mlir_type, tf_data_type, data_type, \
data_type)
#define GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
data_type) \
GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
data_type, data_type)
#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type, \
mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg); \
\
namespace { \
class MlirUnranked##tf_op##mlir_type##mlir_output_type##Op \
: public MlirUnrankedOp< \
tf_data_type, result_data_type, \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirUnrankedOp::MlirUnrankedOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>( \
MLIR_FUNCTION(tf_op, mlir_type, mlir_output_type)(ctx, &args[0])); \
} \
}; \
#define GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg); \
\
namespace { \
class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
: public MlirOp<tf_data_type, result_data_type, \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirOp::MlirOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>(MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)(ctx, &args[0])); \
} \
}; \
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_
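For orientation, hand-expanding GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f16, DT_HALF, Eigen::half) through the renamed macros above gives roughly the following (a sketch of the expansion, assuming it sits inside namespace tensorflow with base_op.h included; it is not literal repository code):

extern "C" UntypedUnrankedMemRefType _mlir_ciface_Abs_GPU_f16_f16(
    tensorflow::OpKernelContext* ctx,
    const ::UnrankedMemRefType<Eigen::half>* arg);

namespace {
class MlirAbsGPUf16f16Op
    : public MlirOp<DT_HALF, Eigen::half, MlirAbsGPUf16f16Op, Eigen::half> {
 public:
  using MlirOp::MlirOp;

  static ::UnrankedMemRefType<Eigen::half> Invoke(
      OpKernelContext* ctx,
      llvm::ArrayRef<::UnrankedMemRefType<Eigen::half>> args) {
    return ConvertToTyped<Eigen::half>(
        _mlir_ciface_Abs_GPU_f16_f16(ctx, &args[0]));
  }
};
}  // namespace

REGISTER_KERNEL_BUILDER(
    Name("Abs").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    MlirAbsGPUf16f16Op);

The platform now appears in both the class name (MlirAbsGPUf16f16Op) and the C interface symbol (_mlir_ciface_Abs_GPU_f16_f16), which is what lets the same macros back CPU kernels as well as GPU ones.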

View File

@ -49,9 +49,11 @@ def _gen_mlir_op_impl(ctx):
inputs = [ctx.file.template],
outputs = [ctx.outputs.out],
command = (
(("cat %s | sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' | " +
"sed 's/_output_type/_%s/g' | sed 's/output_type/%s/g' > %s")) % (
(("cat %s | sed 's/platform/%s/g' | sed 's/_elem_type/_%s/g' | " +
"sed 's/elem_type/%s/g' | " + "sed 's/_output_type/_%s/g' | " +
"sed 's/output_type/%s/g' > %s")) % (
ctx.file.template.path,
ctx.attr.platform.upper(),
ctx.attr.type,
mlir_type,
ctx.attr.output_type,
@ -68,22 +70,26 @@ _gen_mlir_op_rule = rule(
"template": attr.label(mandatory = True, allow_single_file = True),
"type": attr.string(mandatory = True),
"output_type": attr.string(mandatory = True),
"platform": attr.string(mandatory = True),
"out": attr.output(mandatory = True),
},
)
def _gen_mlir_op(name, type, output_type):
def _gen_mlir_op(name, type, platform, output_type):
_gen_mlir_op_rule(
name = "generate_{name}_{type}_{output_type}_mlir".format(
name = "generate_{name}_{platform}_{type}_{output_type}_mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
template = "op_definitions/{name}.mlir.tmpl".format(name = name),
platform = platform,
type = type,
output_type = output_type,
out = "{name}_{type}_{output_type}.mlir".format(
out = "{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
@ -177,6 +183,7 @@ def gen_kernel_library(
tile_size,
output_types = None,
tags = [],
platform = "gpu",
unroll_factors = None,
extra_args = []):
""" Generate a library with kernels for a specific tensorflow op.
@ -190,6 +197,7 @@ def gen_kernel_library(
entry in output_types. By default, output_types = types is
assumed.
tags: The tags which should be added to the library.
platform: Platform for which to compile, either "cpu" or "gpu".
unroll_factors: The unrolling specification, e.g. "4,4"
extra_args: Extra arguments to pass to the generator tool.
"""
@ -200,17 +208,20 @@ def gen_kernel_library(
for (type, output_type) in zip(types, output_types):
_gen_mlir_op(
name = name,
platform = platform,
type = type,
output_type = output_type,
)
_gen_kernel_fatbin_rule(
name = "{name}_{type}_{output_type}_kernel_generator".format(
name = "{name}_{platform}_{type}_{output_type}_kernel_generator".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
mlir_op = "{name}_{type}_{output_type}.mlir".format(
mlir_op = "{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
@ -223,8 +234,9 @@ def gen_kernel_library(
# We have to use a sh_test instead of build_test because build_test doesn't properly find the dependent targets.
native.sh_test(
name = "{name}_{type}_{output_type}_gen_test".format(
name = "{name}_{platform}_{type}_{output_type}_gen_test".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
@ -232,16 +244,18 @@ def gen_kernel_library(
tags = ["no_rocm"],
args = [
"$(location //tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel)",
"$(location {name}_{type}_{output_type}.mlir)".format(
"$(location {name}_{platform}_{type}_{output_type}.mlir)".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
],
size = "medium",
data = [
":{name}_{type}_{output_type}.mlir".format(
":{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
@ -252,8 +266,9 @@ def gen_kernel_library(
native.cc_library(
name = name + "_kernels",
compatible_with = get_compatible_with_cloud(),
deps = if_gpu_is_configured([":{name}_{type}_{output_type}_kernel_generator".format(
deps = if_gpu_is_configured([":{name}_{platform}_{type}_{output_type}_kernel_generator".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
) for (type, output_type) in zip(types, output_types)]),

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f64, DT_DOUBLE, double);
// TODO(b/25387198): Add an int32 kernel.
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, i64, DT_INT64, int64);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Acos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Acos, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Acosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Acosh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, i64, DT_INT64, int64);
} // namespace tensorflow

View File

@ -0,0 +1,30 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_GPU_KERNEL2(Angle, c64, f32, DT_FLOAT, float,
std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Angle, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Angle, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Angle, c128, f64, double, std::complex<double>);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Atan, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Atan, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i64, DT_INT64, int64);
// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);
} // namespace tensorflow

View File

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i64, DT_INT64, int64);
// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
} // namespace tensorflow

View File

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i64, DT_INT64, int64);
// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);
} // namespace tensorflow

View File

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -16,15 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_BINARY_KERNEL2(Complex, f32, c64, DT_COMPLEX64, std::complex<float>,
float);
REGISTER_COMPLEX_KERNEL(Complex, f32, c64, std::complex<float>, float);
GENERATE_BINARY_KERNEL2(Complex, f64, c128, DT_COMPLEX128, std::complex<double>,
double);
REGISTER_COMPLEX_KERNEL(Complex, f64, c128, std::complex<double>, double);
GENERATE_BINARY_GPU_KERNEL2(Complex, f32, c64, DT_COMPLEX64,
std::complex<float>, float);
REGISTER_COMPLEX_GPU_KERNEL(Complex, f32, c64, std::complex<float>, float);
GENERATE_BINARY_GPU_KERNEL2(Complex, f64, c128, DT_COMPLEX128,
std::complex<double>, double);
REGISTER_COMPLEX_GPU_KERNEL(Complex, f64, c128, std::complex<double>, double);
} // namespace tensorflow

View File

@ -16,15 +16,16 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(ComplexAbs, c64, f32, DT_FLOAT, float,
std::complex<float>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(ComplexAbs, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c64, f32, DT_FLOAT, float,
std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c128, f64, double,
std::complex<double>);
} // namespace tensorflow

View File

@ -16,13 +16,13 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, c64, DT_COMPLEX64,
std::complex<float>);
GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, c128, DT_COMPLEX128,
std::complex<double>);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c64, DT_COMPLEX64,
std::complex<float>);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c128, DT_COMPLEX128,
std::complex<double>);
} // namespace tensorflow

View File

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f64, DT_DOUBLE, double);
} // namespace tensorflow

View File

@ -0,0 +1,25 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f64, DT_DOUBLE, double);
} // namespace tensorflow
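For reference, these Digamma kernels evaluate the digamma function psi(x) = d/dx ln Gamma(x) element-wise. A minimal standalone C++ sketch of that definition, approximating psi numerically from std::lgamma (purely illustrative and unrelated to the generated kernels themselves):

#include <cmath>
#include <cstdio>

int main() {
  // digamma(x) = d/dx ln(Gamma(x)); approximate it with a central difference
  // over std::lgamma. For x = 3.5 the true value is roughly 1.1032.
  const double x = 3.5;
  const double h = 1e-6;
  const double approx = (std::lgamma(x + h) - std::lgamma(x - h)) / (2.0 * h);
  std::printf("digamma(%.1f) ~= %.6f\n", x, approx);
  return 0;
}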


@ -14,21 +14,21 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i64, DT_INT64, int64);
REGISTER_ALIASED_KERNEL(RealDiv, Div, f16, f16, Eigen::half)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f32, f32, float)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f64, f64, double)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f16, f16, Eigen::half)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f32, f32, float)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f64, f64, double)
REGISTER_ALIASED_KERNEL(TruncateDiv, Div, i16, i16, int16)
REGISTER_ALIASED_KERNEL(TruncateDiv, Div, i64, i64, int64)
REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i16, i16, int16)
REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i64, i64, int64)
} // namespace tensorflow


@ -16,18 +16,18 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i16, i1, DT_BOOL, bool, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i64, i1, DT_BOOL, bool, int64);
} // namespace tensorflow
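The recurring TODO above reflects that TensorFlow keeps int32 tensors in host memory on GPU devices, so a plain DEVICE_GPU registration is not sufficient for that type. A minimal sketch of the conventional workaround, pinning the int32 operands to host memory with the standard registration API (the kernel class below is hypothetical; the actual fix tracked by b/25387198 may look different):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {
// Hypothetical stand-in for a comparison kernel that runs on host data.
class HypotheticalEqualInt32Op : public OpKernel {
 public:
  using OpKernel::OpKernel;
  void Compute(OpKernelContext* ctx) override {
    // Element-wise comparison of host-resident int32 tensors; omitted here.
  }
};
}  // namespace

// "x", "y" and "z" are pinned to host memory because int32 values stay on
// the host even when the op is placed on a GPU device.
REGISTER_KERNEL_BUILDER(Name("Equal")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("y")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        HypotheticalEqualInt32Op);
}  // namespace tensorflow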


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f16, DT_HALF, Eigen::half);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -16,17 +16,21 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f32, i1, DT_BOOL, bool,
float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f64, i1, DT_BOOL, bool,
double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i16, i1, DT_BOOL, bool,
int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i64, i1, DT_BOOL, bool,
int64);
} // namespace tensorflow


@ -16,21 +16,22 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f32, i1, DT_BOOL, bool,
float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f64, i1, DT_BOOL, bool,
double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i16, i1, DT_BOOL, bool,
int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f32, i1, DT_BOOL, bool,
float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f64, i1, DT_BOOL, bool,
double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i8, i1, DT_BOOL, bool,
int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i16, i1, DT_BOOL, bool,
int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i64, i1, DT_BOOL, bool,
int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i64, i1, DT_BOOL, bool,
int64);
} // namespace tensorflow


@ -16,14 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(Imag, c64, f32, DT_FLOAT, float, std::complex<float>);
REGISTER_COMPLEX_KERNEL(Imag, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(Imag, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_KERNEL(Imag, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(Imag, c64, f32, DT_FLOAT, float,
std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Imag, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Imag, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Imag, c128, f64, double, std::complex<double>);
} // namespace tensorflow


@ -16,13 +16,13 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i64, DT_INT64, int64);
} // namespace tensorflow


@ -16,15 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(IsFinite, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsFinite, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsFinite, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsFinite, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsFinite, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsFinite, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsFinite, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsFinite, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsFinite, f64, i1, double);
} // namespace tensorflow


@ -16,15 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(IsInf, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsInf, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsInf, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsInf, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsInf, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsInf, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsInf, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsInf, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsInf, f64, i1, double);
} // namespace tensorflow


@ -16,15 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(IsNan, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsNan, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsNan, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsNan, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsNan, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsNan, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsNan, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsNan, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsNan, f64, i1, double);
} // namespace tensorflow


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i64, DT_INT64, int64);
} // namespace tensorflow


@ -16,16 +16,17 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f16, i1, DT_BOOL, bool, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i16, i1, DT_BOOL, bool, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i64, i1, DT_BOOL, bool, int64);
} // namespace tensorflow


@ -16,17 +16,22 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f32, i1, DT_BOOL, bool,
float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f64, i1, DT_BOOL, bool,
double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i8, i1, DT_BOOL, bool,
int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i16, i1, DT_BOOL, bool,
int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i64, i1, DT_BOOL, bool,
int64);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_BINARY_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
GENERATE_BINARY_GPU_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
// LogicalAnd does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1, i1);
} // namespace tensorflow
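Since LogicalAnd carries no "T" attribute, the registration above cannot add a TypeConstraint. A minimal sketch of that difference in terms of the standard registration API (the kernel class below is hypothetical; the real expansion of the GPU macros in base_gpu_op.h may differ):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {
// Hypothetical stand-in for the generated boolean kernel.
class HypotheticalLogicalAndOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  void Compute(OpKernelContext* ctx) override {
    // Element-wise AND over bool tensors; omitted in this sketch.
  }
};
}  // namespace

// A typed op is normally constrained on its "T" attribute, e.g.
//   Name("Maximum").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T")
// LogicalAnd only ever operates on bool and has no "T" attribute, so the
// builder simply omits the constraint.
REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_GPU),
                        HypotheticalLogicalAndOp);
}  // namespace tensorflow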


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL(LogicalNot, i1, DT_BOOL, bool);
GENERATE_UNARY_GPU_KERNEL(LogicalNot, i1, DT_BOOL, bool);
// LogicalNot does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1, i1);
} // namespace tensorflow


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_BINARY_KERNEL(LogicalOr, i1, DT_BOOL, bool);
GENERATE_BINARY_GPU_KERNEL(LogicalOr, i1, DT_BOOL, bool);
// LogicalOr does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1, i1);
} // namespace tensorflow


@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i64, DT_INT64, int64);
} // namespace tensorflow


@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i64, DT_INT64, int64);
} // namespace tensorflow


@ -14,16 +14,16 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i8, DT_INT8, int8);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i64, DT_INT64, int64);
} // namespace tensorflow


@ -14,16 +14,16 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i64, DT_INT64, int64);
} // namespace tensorflow


@ -16,18 +16,22 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f16, i1, DT_BOOL, bool,
Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f32, i1, DT_BOOL, bool,
float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f64, i1, DT_BOOL, bool,
double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i16, i1, DT_BOOL, bool,
int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i64, i1, DT_BOOL, bool,
int64);
} // namespace tensorflow


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, i64, DT_INT64, int64);
} // namespace tensorflow


@ -16,14 +16,15 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_UNARY_KERNEL2(Real, c64, f32, DT_FLOAT, float, std::complex<float>);
REGISTER_COMPLEX_KERNEL(Real, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(Real, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_KERNEL(Real, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(Real, c64, f32, DT_FLOAT, float,
std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Real, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Real, c128, f64, DT_DOUBLE, double,
std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Real, c128, f64, double, std::complex<double>);
} // namespace tensorflow


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i64, DT_INT64, int64);
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f64, DT_DOUBLE, double);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, i64, DT_INT64, int64);
// TODO(b/162577610): Register the kernel for complex types and bfloat.
} // namespace tensorflow


@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f64, DT_DOUBLE, double);
} // namespace tensorflow


@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
namespace tensorflow {
GENERATE_AND_REGISTER_UNARY_KERNEL(Sinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sinh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f64, DT_DOUBLE, double);
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff.