Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/onednn-centos7
commit 5381666f28
@@ -555,6 +555,15 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
   let hasCanonicalizer = 1;
 }
 
+def HLOClient_DigammaOp : HLOClient_UnaryElementwiseOp<"digamma",
+    [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
+  let summary = "Digamma function";
+
+  let description = [{
+    Returns `Digamma(operand)` element-wise.
+  }];
+}
+
 def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
     [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
   let summary = "Erfc operator";
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir {
@@ -588,6 +589,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
                                                   ArrayRef<Value> args,
                                                   OpBuilder* b) {
   lmhlo::PowOp::Adaptor adaptor(args);
+  auto lb = ImplicitLocOpBuilder(loc, *b);
   // Floating point can use std::powf
   auto result_type = result_types.front();
   if (result_type.isa<::mlir::FloatType>())
@@ -597,53 +599,66 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
   assert(result_type.isa<::mlir::IntegerType>() &&
          "only float and integer `pow` is supported right now");
 
-  // There is no powi, so lower to a simple product.
-  Value neg_one =
-      b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1));
-  Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0));
-  Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
-  Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2));
+  // Exponentiation by squaring:
+  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
+  Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1));
+  Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0));
+  Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1));
+  Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2));
+  Value step = lb.create<ConstantIndexOp>(1);
+  Value lowerBound = lb.create<ConstantIndexOp>(0);
+  // Everything else would overflow for any exponent > 1, as 2^64
+  // is the largest possible exponent for a 64-bit integer, and
+  // that's 1 << 6.
+  Value upperBound = lb.create<ConstantIndexOp>(6);
+  auto original_base = adaptor.lhs();
+  auto original_exponent = adaptor.rhs();
 
-  Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
-  Value upperBound =
-      b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
-  Value step = b->create<ConstantIndexOp>(loc, 1);
-  Value for_result =
-      b->create<scf::ForOp>(
-           loc, lowerBound, upperBound, step, llvm::makeArrayRef(one),
-           [&](OpBuilder& b, Location l, Value v, ValueRange iters) {
-             Value prod =
-                 b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
-             b.create<scf::YieldOp>(l, prod);
-           })
+  Value accum =
+      lb.create<scf::ForOp>(
+            lowerBound, upperBound, step,
+            SmallVector<Value>({one, original_base, original_exponent}),
+            [&](OpBuilder& b, Location, Value v, ValueRange iters) {
+              Value accum = iters[0];
+              Value base = iters[1];
+              Value exponent = iters[2];
+
+              Value condition = b.create<CmpIOp>(
+                  loc, CmpIPredicate::eq,
+                  b.create<::mlir::AndOp>(loc, exponent, one), one);
+              Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base);
+              accum =
+                  b.create<::mlir::SelectOp>(loc, condition, multiplied, accum);
+              base = b.create<::mlir::MulIOp>(loc, base, base);
+              exponent =
+                  b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one);
+              b.create<scf::YieldOp>(
+                  loc, SmallVector<Value>({accum, base, exponent}));
+            })
           .getResult(0);
 
-  Value rhs_is_even =
-      b->create<CmpIOp>(loc, CmpIPredicate::eq,
-                        b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero);
+  Value rhs_is_even = lb.create<CmpIOp>(
+      CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero);
   Value rhs_is_negative =
-      b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero);
-  Value lhs_is_one =
-      b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one);
+      lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero);
+  Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one);
   Value lhs_is_neg_one =
-      b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one);
+      lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one);
 
-  // The for_result is correct when the rhs is non-negative. When rhs is
+  // The accum is correct when the rhs is non-negative. When rhs is
   // negative, we return 0 for integer, with the exception of lhs values of 1
   // and -1 which have integer results for negative exponents. Specifically, the
   // calculation is the following:
   //
-  // - Return for_result if the rhs is not negative.
+  // - Return accum if the rhs is not negative.
   // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
   // - Return 1 if lhs is 1.
   // - Else return 0.
-  Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero);
-  Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>(
-      loc, lhs_is_neg_one,
-      b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one),
+  Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero);
+  Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>(
+      lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one),
       if_lhs_is_one);
-  return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one,
-                                     for_result);
+  return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum);
 }
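As a reference for the loop above, here is a standalone scalar sketch (an editorial illustration, not part of the commit) of the same exponentiation-by-squaring algorithm in plain C++:

#include <cstdint>
#include <cstdio>

// Scalar model of the lowering above. Six squaring iterations cover every
// exponent whose result fits in 64 bits: any base with |base| >= 2 already
// overflows past exponent 63, and 63 < (1 << 6).
int64_t IntegerPow(int64_t base, int64_t exponent) {
  int64_t accum = 1, b = base, e = exponent;
  for (int i = 0; i < 6; ++i) {
    if (e & 1) accum *= b;  // select(condition, multiplied, accum)
    b *= b;                 // base = base * base
    e = static_cast<int64_t>(static_cast<uint64_t>(e) >> 1);  // logical shift
  }
  // For negative exponents only bases 1 and -1 have integral results.
  const bool rhs_is_even = (exponent % 2) == 0;
  const int64_t if_lhs_is_one = (base == 1) ? 1 : 0;
  const int64_t if_lhs_is_neg_one =
      (base == -1) ? (rhs_is_even ? 1 : -1) : if_lhs_is_one;
  return (exponent < 0) ? if_lhs_is_neg_one : accum;
}

int main() {
  std::printf("%lld\n", static_cast<long long>(IntegerPow(3, 5)));  // 243
}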
 
 template <>
@@ -1310,6 +1310,40 @@ class DynamicReshapeOpNotActuallyDynamic
   }
 };
 
+// Canonicalizes
+// %0 = some_op(%tensor)
+// %1 = "mhlo.dynamic_reshape"(%0, %shape)
+//        (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
+// ... uses of %1.
+//
+// into
+//
+// ... uses of %0.
+// This canonicalization is only correct if the input is correct!
+// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
+// errors in input, and still allows us to get rid of redundant reshapes.
+class RemoveRedundantRank1DynamicReshape
+    : public OpRewritePattern<DynamicReshapeOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicReshapeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto type = op.result().getType().dyn_cast<RankedTensorType>();
+    if (!type || type.getRank() != 1 || type.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires rank 1 shape tensor with dynamic dimension");
+    }
+    auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
+    if (!operand_type || operand_type.getRank() != 1 ||
+        operand_type.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires rank 1 shape tensor with dynamic dimension");
+    }
+    rewriter.replaceOp(op, {op.operand()});
+    return success();
+  }
+};
+
 // Canonicalizes
 // %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
 // %1 = same_operands_and_result_shape_op(%tensor)
@@ -1354,6 +1388,7 @@ void DynamicReshapeOp::getCanonicalizationPatterns(
       DynamicReshapeOpSameShapeOpResult,
       RemoveRedundantDynamicBroadcast,
       RemoveRedundantDynamicReshape,
+      RemoveRedundantRank1DynamicReshape,
       ShapeOfDynamicReshape
     >(context);
   // clang-format on
@@ -625,6 +625,119 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                                          lgamma);
 }
 
+// Compute the Digamma function using Lanczos' approximation from "A Precision
+// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
+// series B. Vol. 1:
+// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
+// with t(z) = z + kLanczosGamma + 1/2
+//      a(z) = kBaseLanczosCoeff
+//             + sum(k = 1, n, kLanczosCoefficients[k] / (z + k))
+//      a'(z) = - sum(k = 1, n, kLanczosCoefficients[k] / (z + k) / (z + k))
+Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
+                         Value x) {
+  // If the input is less than 0.5 use Euler's reflection formula.
+  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
+  // Let z be
+  //   z = -x      if x < 1/2
+  //   z = x - 1   otherwise
+  const StringAttr kLT = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
+  Value half = getConstantLike(rewriter, loc, 0.5, x);
+  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
+  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
+  Value one = getConstantLike(rewriter, loc, 1, x);
+  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
+  Value z =
+      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
+
+  // Materialize
+  //   a(z) = kBaseLanczosCoeff
+  //          + sum(k = 1, n, kLanczosCoefficients[k] / (z + k))
+  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[k] / (z + k) / (z + k))
+  Value zero = getConstantLike(rewriter, loc, 0.0, x);
+  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
+  Value a_prime = zero;
+  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
+    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
+    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
+    Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
+    a_prime = rewriter.create<mhlo::SubOp>(
+        loc, a_prime,
+        rewriter.create<mhlo::DivOp>(
+            loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
+    a = rewriter.create<mhlo::AddOp>(
+        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
+  }
+
+  // To improve accuracy on platforms with less-precise log implementations,
+  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
+  // device.
+  // Materialize as
+  //   log(t) = log(kLanczosGamma + 1/2 + z)
+  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
+  Value lanczos_plus_half =
+      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
+  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
+  Value log_term =
+      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
+  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
+      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
+  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
+
+  // Materialize the final result (modulo reflection) as
+  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
+  Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
+  Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
+      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
+  Value digamma = rewriter.create<mhlo::SubOp>(
+      loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
+      lanczos_gamma_div_t);
+
+  // We need to be careful how we compute cot(pi * input) below: For
+  // near-integral arguments, pi * input can lose precision.
+  //
+  // Input is already known to be less than 0.5 (otherwise we don't have to
+  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
+  // increase precision of pi * x and the resulting cotangent.
+  Value reduced_x = rewriter.create<mhlo::AddOp>(
+      loc, x,
+      rewriter.create<mhlo::AbsOp>(
+          loc, rewriter.create<mhlo::FloorOp>(
+                   loc, rewriter.create<mhlo::AddOp>(
+                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
+
+  // Materialize reflection for inputs less than 0.5 as
+  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
+  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
+  Value pi = getConstantLike(rewriter, loc, M_PI, x);
+  Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
+  Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
+  Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
+  Value reflection = rewriter.create<mhlo::SubOp>(
+      loc, digamma,
+      rewriter.create<mhlo::DivOp>(
+          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));
+
+  // Select whether or not to rely on the reflection.
+  digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
+                                            digamma);
+
+  // Digamma has poles at negative integers and zero; return nan for those.
+  const StringAttr kLE = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
+  Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
+  const StringAttr kEQ = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
+  Value is_int = rewriter.create<mhlo::CompareOp>(
+      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
+  Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
+  return rewriter.create<mhlo::SelectOp>(
+      loc, is_pole,
+      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
+                      x),
+      digamma);
+}
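For readers following the math, here is a scalar C++ sketch of the same Lanczos-based digamma (an editorial illustration, not the conversion code). The constants mirror kBaseLanczosCoeff and kLanczosCoefficients as they appear in the f64 test expectations further below:

#include <cmath>
#include <cstdio>
#include <limits>

double Digamma(double x) {
  const double kPi = 3.14159265358979323846;
  const double kLanczosGamma = 7.0;
  const double kBaseLanczosCoeff = 0.99999999999980993;
  const double kCoeffs[] = {676.5203681218851,     -1259.1392167224028,
                            771.32342877765313,    -176.61502916214059,
                            12.507343278686905,    -0.13857109526572012,
                            9.9843695780195716e-6, 1.5056327351493116e-7};
  const bool need_to_reflect = x < 0.5;
  const double z = need_to_reflect ? -x : x - 1.0;

  double a = kBaseLanczosCoeff, a_prime = 0.0;
  for (int k = 1; k <= 8; ++k) {
    const double z_term = z + k;
    a_prime -= kCoeffs[k - 1] / (z_term * z_term);
    a += kCoeffs[k - 1] / z_term;
  }

  // log(t) = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2))
  const double lanczos_plus_half = kLanczosGamma + 0.5;
  const double log_t =
      std::log(lanczos_plus_half) + std::log1p(z / lanczos_plus_half);
  double digamma =
      log_t + a_prime / a - kLanczosGamma / (lanczos_plus_half + z);

  if (need_to_reflect) {
    // Shift x into [-0.5, 0.5] before forming pi * x, as the code above
    // does, to keep the cotangent accurate near negative integers.
    const double reduced_x = x + std::abs(std::floor(x + 0.5));
    digamma -= kPi * std::cos(kPi * reduced_x) / std::sin(kPi * reduced_x);
  }
  if (x <= 0.0 && x == std::floor(x))  // poles at zero and negative integers
    return std::numeric_limits<double>::quiet_NaN();
  return digamma;
}

int main() { std::printf("%.15f\n", Digamma(1.0)); }  // approx. -0.577215665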
 
 struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
   using OpConversionPattern<LgammaOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(
@@ -639,6 +752,20 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
   }
 };
 
+struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
+  using OpConversionPattern<DigammaOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      DigammaOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    DigammaOp::Adaptor transformed(operands);
+    FloatType min_precision_ty = rewriter.getF32Type();
+    rewriter.replaceOp(
+        op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
+                                  min_precision_ty, &MaterializeDigamma));
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -790,8 +917,13 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                                          context, patterns, 5);
 
   // Other patterns.
-  patterns->insert<ConvertConstantLikeOp, ConvertErfOp, ConvertErfcOp,
+  // clang-format off
+  patterns->insert<ConvertConstantLikeOp,
+                   ConvertDigammaOp,
+                   ConvertErfOp,
+                   ConvertErfcOp,
                    ConvertLgammaOp>(context);
+  // clang-format on
 }
 
 }  // namespace chlo
@@ -53,8 +53,9 @@ namespace {
 // TODO(herhut): Generate these out of op definitions.
 #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
   fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
-      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) \
-          sep fn(ErfcOp) sep fn(LgammaOp) sep fn(SinhOp) sep fn(TanOp)
+      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp) \
+          sep fn(ErfOp) sep fn(ErfcOp) sep fn(LgammaOp) sep fn(SinhOp) \
+              sep fn(TanOp)
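The macro above is an X-macro: the op list is written once, and each client supplies `fn` (applied to every op name) and `sep` (spliced between expansions). A minimal self-contained sketch of the idiom, with illustrative names only:

#include <cstdio>

#define MAP_UNARY_OPS(fn, sep) fn(Erf) sep fn(Erfc) sep fn(Digamma)

#define PRINT_NAME(op) std::puts(#op)
#define SEP ;

int main() {
  // Expands to: std::puts("Erf") ; std::puts("Erfc") ; std::puts("Digamma");
  MAP_UNARY_OPS(PRINT_NAME, SEP);
}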
 
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
@@ -594,12 +594,26 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) ->
   return %1 : tensor<2xindex>
 }
 
+// CHECK-LABEL: func @dynamic_reshape_rank_1_to_rank_1
+// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
+func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor<?xcomplex<f32>>,
+    %shape: tensor<?xindex>) -> tensor<?xf32> {
+  // CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.real"([[ARG0]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
+  // CHECK: return [[RES]]
+  %0 = "mhlo.real"(%arg0): (tensor<?xcomplex<f32>>) -> tensor<?xf32>
+  %1 = shape.shape_of %arg0 : tensor<?xcomplex<f32>> -> tensor<1xindex>
+  %2 = shape.num_elements %1 : tensor<1xindex> -> index
+  %3 = tensor.from_elements %2 : tensor<1xindex>
+  %4 = "mhlo.dynamic_reshape"(%0, %3)
+    : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape
 // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
 func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> {
-  // CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
-  // CHECK: return [[RES]]
+  // CHECK: return [[ARG0]]
   %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
   %1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
   %2 = shape.num_elements %1 : tensor<?xindex> -> index
@@ -875,3 +875,233 @@ func @lgamma_f16(%arg : tensor<f16>) -> tensor<f16> {
   %1 = chlo.lgamma %arg : tensor<f16> -> tensor<f16>
   return %1 : tensor<f16>
 }
+
+// CHECK-LABEL: @digamma_f64
+// CHECK-SAME: (%[[ARG:.*]]: tensor<f64>)
+func @digamma_f64(%arg : tensor<f64>) -> tensor<f64> {
+  // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
+  // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"}
+  // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0)
+  // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
+  // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]]
+  // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
+  // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<0.99999999999980993>
+  // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.5203681218851>
+  // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
+  // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
+  // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
+  // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
+  // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
+  // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
+  // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
+  // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.1392167224028>
+  // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
+  // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
+  // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
+  // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
+  // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
+  // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
+  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
+  // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.32342877765313>
+  // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
+  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
+  // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
+  // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
+  // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
+  // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
+  // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
+  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.61502916214059>
+  // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
+  // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
+  // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
+  // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
+  // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
+  // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
+  // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
+  // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.507343278686905>
+  // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
+  // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
+  // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
+  // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
+  // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
+  // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
+  // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
+  // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.13857109526572012>
+  // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
+  // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
+  // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
+  // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
+  // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
+  // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
+  // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
+  // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.9843695780195716E-6>
+  // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
+  // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
+  // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
+  // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
+  // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
+  // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
+  // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
+  // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.5056327351493116E-7>
+  // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
+  // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
+  // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
+  // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
+  // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
+  // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
+  // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
+  // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
+  // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
+  // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.0149030205422647>
+  // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
+  // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
+  // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
+  // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
+  // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
+  // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
+  // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
+  // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
+  // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
+  // CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]]
+  // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
+  // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
+  // CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]]
+  // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.1415926535897931>
+  // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
+  // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
+  // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
+  // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
+  // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
+  // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
+  // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
+  // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"}
+  // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0)
+  // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"}
+  // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
+  // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FF8000000000000>
+  // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
+  // CHECK: return %[[RES]]
+  %1 = chlo.digamma %arg : tensor<f64> -> tensor<f64>
+  return %1 : tensor<f64>
+}
+
+// CHECK-LABEL: @digamma_f32
+// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
+func @digamma_f32(%arg : tensor<f32>) -> tensor<f32> {
+  // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
+  // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"}
+  // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0)
+  // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
+  // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]]
+  // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
+  // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<1.000000e+00>
+  // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.520386>
+  // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
+  // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
+  // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
+  // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
+  // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
+  // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
+  // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
+  // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.13916>
+  // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
+  // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
+  // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
+  // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
+  // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
+  // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
+  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
+  // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.323425>
+  // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
+  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
+  // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
+  // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
+  // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
+  // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
+  // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
+  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.615036>
+  // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
+  // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
+  // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
+  // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
+  // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
+  // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
+  // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
+  // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.5073433>
+  // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
+  // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
+  // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
+  // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
+  // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
+  // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
+  // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
+  // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.138571098>
+  // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
+  // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
+  // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
+  // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
+  // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
+  // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
+  // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
+  // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.98436917E-6>
+  // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
+  // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
+  // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
+  // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
+  // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
+  // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
+  // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
+  // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.50563267E-7>
+  // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
+  // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
+  // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
+  // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
+  // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
+  // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
+  // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
+  // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
+  // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
+  // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.01490307>
+  // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
+  // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
+  // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
+  // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
+  // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
+  // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
+  // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
+  // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
+  // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
+  // CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]]
+  // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
+  // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
+  // CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]]
+  // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.14159274>
+  // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
+  // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
+  // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
+  // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
+  // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
+  // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
+  // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
+  // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"}
+  // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0)
+  // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"}
+  // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
+  // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FC00000>
+  // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
+  // CHECK: return %[[RES]]
+  %1 = chlo.digamma %arg : tensor<f32> -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: @digamma_f16
+// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
+func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RES]]
+  %1 = chlo.digamma %arg : tensor<f16> -> tensor<f16>
+  return %1 : tensor<f16>
+}
@@ -163,7 +163,7 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
 // CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
 // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 
-// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
+// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
 
 // CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
 
@@ -851,12 +851,19 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
 // CHECK: ^{{[a-z0-9_]*}}
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
-// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
-// CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
-// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
-// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
-// CHECK: scf.yield %[[ACCUM]]
+// CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1
+// CHECK-SAME: iter_args(
+// CHECK-SAME: %[[ITER0:.*]] = %c1
+// CHECK-SAME: %[[ITER1:.*]] = %[[ARG0]]
+// CHECK-SAME: %[[ITER2:.*]] = %[[ARG1]]
+// CHECK-SAME: ) -> (i32, i32, i32) {
+// CHECK: %[[AND:[a-zA-Z0-9_]*]] = and %[[ITER2]], %c1
+// CHECK: %[[COND:[a-zA-Z0-9_]*]] = cmpi eq, %[[AND]], %c1
+// CHECK: %[[MUL:[a-zA-Z0-9_]*]] = muli %[[ITER0]], %[[ITER1]]
+// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = select %[[COND]], %[[MUL]], %[[ITER0]]
+// CHECK: %[[BASE:[a-zA-Z0-9_]*]] = muli %[[ITER1]], %[[ITER1]]
+// CHECK: %[[EXP:[a-zA-Z0-9_]*]] = shift_right_unsigned %[[ITER2]], %c1
+// CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]]
 // CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2
 // CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0
 // CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0
@@ -865,7 +872,7 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
 // CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0
 // CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1
 // CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
-// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]
+// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0
 // CHECK: linalg.yield %[[RESULT]]
   %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
                                    tensor<2x2xi32>) -> tensor<2x2xi32>
@@ -1526,5 +1526,57 @@ void PopulateLoweringTFPatterns(MLIRContext *context,
   populateWithGenerated(context, *patterns);
 }
 
+void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns) {
+  // clang-format off
+  patterns->insert<
+      ConvertFakeQuantWithMinMaxVarsOp,
+      LowerAddNOp,
+      LowerBatchToSpaceND,
+      LowerDynamicStitchOp<DynamicStitchOp>,
+      LowerDynamicStitchOp<ParallelDynamicStitchOp>,
+      LowerInvertPermutationOp,
+      LowerPackOp,
+      LowerResizeNearestNeighbor,
+      LowerSpaceToBatchNDOp,
+      LowerSparseMatMulOp,
+      Lower_UnaryOpsComposition>(context);
+  // clang-format on
+
+  // Populate the relevant generated patterns.
+  // clang-format off
+  patterns->insert<
+      LowerBiasAddGradOp,
+      LowerDivNoNanOp,
+      LowerEmptyOp,
+      LowerExpm1Op,
+      LowerFakeQuantWithMinMaxArgs,
+      LowerFillOp,
+      LowerIsInfOp,
+      LowerIsNanOp,
+      LowerL2LossOp,
+      LowerMulNoNanOp,
+      LowerOnesLikeOp,
+      LowerPadOp,
+      LowerReciprocal,
+      LowerRintOp,
+      LowerRoundOpOnFloatTensor,
+      LowerRoundOpOnIntTensor,
+      LowerRsqrtGradOp,
+      LowerScatterNdOp,
+      LowerSizeOp,
+      LowerSoftmaxCrossEntropyWithLogitsOp,
+      LowerSparseSoftmaxCrossEntropyWithLogitsOp,
+      LowerSquareOp,
+      LowerSquaredDifferenceOpOnRealTensors,
+      LowerSquaredDifferenceOpOneComplexTensors,
+      LowerTanhGradOp,
+      LowerXdivyOp,
+      LowerXlog1pyOp,
+      LowerXlogyOp,
+      LowerZerosLikeOp>(context);
+  // clang-format on
+}
+
 }  // namespace TF
 }  // namespace mlir
@@ -27,6 +27,14 @@ namespace TF {
 void PopulateLoweringTFPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns);
 
+// Populates TensorFlow lowering patterns to lower some of the TensorFlow
+// operations that can be represented by means of other TensorFlow operations.
+// This pattern collection preserves those TensorFlow operations that will later
+// be lowered to equivalent operations in CHLO or MHLO. This allows for
+// HLO-specific lowerings.
+void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns);
+
 }  // namespace TF
 }  // namespace mlir
 
@@ -134,16 +134,18 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
 // Difference op patterns.
 //===----------------------------------------------------------------------===//
 
 def ComplexTensor : TensorOf<[AnyComplex]>;
 def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>;
 
-def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
+def LowerSquareOp : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
 
-def : Pat<(TF_SquaredDifferenceOp RealTensor: $lhs, RealTensor:$rhs),
-          (TF_SquareOp (TF_SubOp $lhs, $rhs))>;
+def LowerSquaredDifferenceOpOnRealTensors : Pat<
+  (TF_SquaredDifferenceOp RealTensor: $lhs, RealTensor:$rhs),
+  (TF_SquareOp (TF_SubOp $lhs, $rhs))>;
 
-def : Pat<(TF_SquaredDifferenceOp ComplexTensor: $lhs, ComplexTensor:$rhs),
-          (TF_MulOp (TF_SubOp:$diff $lhs, $rhs), (TF_ConjOp $diff))>;
+def LowerSquaredDifferenceOpOneComplexTensors : Pat<
+  (TF_SquaredDifferenceOp ComplexTensor: $lhs, ComplexTensor:$rhs),
+  (TF_MulOp (TF_SubOp:$diff $lhs, $rhs), (TF_ConjOp $diff))>;
 
 //===----------------------------------------------------------------------===//
 // DivNoNan and MulNoNan op patterns.
@@ -156,9 +158,9 @@ class BinaryNoNanPat<Op FromOp, Op ToOp>
                    /*incompatible_shape_error*/ConstBoolAttrTrue),
     $zero, (ToOp $l, $r))>;
 
-foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp],
-                         [TF_MulNoNanOp, TF_MulOp]] in
-  def : BinaryNoNanPat<fromToBinPair[0], fromToBinPair[1]>;
+def LowerDivNoNanOp : BinaryNoNanPat<TF_DivNoNanOp, TF_DivOp>;
+
+def LowerMulNoNanOp : BinaryNoNanPat<TF_MulNoNanOp, TF_MulOp>;
 
 //===----------------------------------------------------------------------===//
 // Expm1 op patterns.
@@ -226,9 +228,13 @@ def LowerL2LossOp :
 // Pad op patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
-          (TF_PadV2Op $input, $paddings,
-                      (TF_ConstOp (GetScalarOfType<0> $input)))>;
+def LowerPadOp : Pat<
+  (TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
+  (TF_PadV2Op $input, $paddings,
+    (TF_ConstOp
+      (GetScalarOfType<0> $input)
+    )
+  )>;
 
 //===----------------------------------------------------------------------===//
 // Reciprocal op patterns.
@@ -243,45 +249,76 @@ def LowerReciprocal : Pat<(TF_ReciprocalOp $x),
 
 // Rint is specified as RoundHalfToEven, which happens to be the same behavior
 // as TF_RoundOp, so lower to TF_RoundOp.
-def : Pat<(TF_RintOp:$res TF_FloatTensor:$input), (TF_RoundOp $input)>;
+def LowerRintOp : Pat<(TF_RintOp:$res TF_FloatTensor:$input), (TF_RoundOp $input)>;
 
 // Rounds on integers should just be bypassed.
-def : Pat<(TF_RoundOp:$res TF_IntTensor:$input), (TF_IdentityOp $input)>;
+def LowerRoundOpOnIntTensor : Pat<
+  (TF_RoundOp:$res TF_IntTensor:$input),
+  (TF_IdentityOp $input)>;
 
 // Implements TF Round on floats using basic operations. TF Round is specified
 // as RoundHalfToEven to be compatible with Numpy.
-def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
-          (TF_SelectOp
-            (TF_LessOp
-              (TF_SubOp $input, (TF_FloorOp:$floor $input)),
-              (TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
-            $floor,
-            (TF_AddV2Op
-              (TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;
+
+def LowerRoundOpOnFloatTensor : Pat<
+  (TF_RoundOp:$res TF_FloatTensor:$input),
+  (TF_SelectOp
+    (TF_LessOp
+      (TF_SubOp
+        $input,
+        (TF_FloorOp:$floor $input)
+      ),
+      (TF_ConstOp
+        (GetScalarOfFloatType<"0.5"> $input)
+      )
+    ),
+    $floor,
+    (TF_AddV2Op
+      (TF_ConstOp
+        (GetScalarOfType<1> $input)
+      ),
+      $floor
+    )
+  )>;
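An editorial scalar sketch of the arithmetic this pattern materializes — keep floor(x) when the fractional part is below 0.5, otherwise add one:

#include <cmath>
#include <cstdio>

float RoundLikePattern(float x) {
  float fl = std::floor(x);            // TF_FloorOp
  return (x - fl < 0.5f) ? fl          // TF_SubOp, TF_LessOp, TF_SelectOp
                         : fl + 1.0f;  // TF_AddV2Op with the constant 1
}

int main() {
  // Prints: 2 3 3 -2 (at an exact .5 fraction this select rounds upward).
  std::printf("%g %g %g %g\n", RoundLikePattern(2.3f), RoundLikePattern(2.7f),
              RoundLikePattern(2.5f), RoundLikePattern(-2.5f));
}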
 
 //===----------------------------------------------------------------------===//
 // Rsqrt op patterns.
 //===----------------------------------------------------------------------===//
 
 // RsqrtGrad(lhs, rhs) = (lhs * lhs * lhs) * (rhs / -2)
-def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
-          (TF_MulOp (TF_MulOp (TF_MulOp $lhs, $lhs), $lhs),
-                    (TF_DivOp $rhs,
-                              (TF_ConstOp (GetScalarOfType<-2> $rhs))))>;
+
+def LowerRsqrtGradOp : Pat<
+  (TF_RsqrtGradOp $lhs, $rhs),
+  (TF_MulOp
+    (TF_MulOp
+      (TF_MulOp $lhs, $lhs),
+      $lhs
+    ),
+    (TF_DivOp
+      $rhs,
+      (TF_ConstOp
+        (GetScalarOfType<-2> $rhs)
+      )
+    )
+  )>;
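A hedged numeric sanity check of the identity in the comment above: since d/dx rsqrt(x) = -rsqrt(x)^3 / 2, the backward value is (lhs * lhs * lhs) * (rhs / -2) with lhs the forward rsqrt result and rhs the incoming gradient.

#include <cmath>
#include <cstdio>

int main() {
  const double x = 3.0, rhs = 1.5;        // arbitrary test point and gradient
  const double lhs = 1.0 / std::sqrt(x);  // forward rsqrt result
  const double pattern = (lhs * lhs * lhs) * (rhs / -2.0);
  const double eps = 1e-7;                // finite-difference reference
  const double fd = (1.0 / std::sqrt(x + eps) - lhs) / eps * rhs;
  std::printf("pattern=%.9f finite-diff=%.9f\n", pattern, fd);
}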
 
 //===----------------------------------------------------------------------===//
 // Size op patterns.
 //===----------------------------------------------------------------------===//
 
 // Size(x) = Prod(Shape(x), reduction_indices=0, keep_dims=false)
-def : Pat<(TF_SizeOp:$res $arg),
-          (TF_ProdOp
-            (CreateTFShapeOp $res, $arg, (IsI32 $res)),
-            /*reduction_indices=*/(TF_ConstOp (GetScalarOfType<0> $res)),
-            /*keep_dims=*/ConstBoolAttrFalse)>;
+
+def LowerSizeOp : Pat<
+  (TF_SizeOp:$res $arg),
+  (TF_ProdOp
+    (CreateTFShapeOp
+      $res,
+      $arg,
+      (IsI32 $res)
+    ),
+    /*reduction_indices=*/
+    (TF_ConstOp
+      (GetScalarOfType<0> $res)
+    ),
+    /*keep_dims=*/
+    ConstBoolAttrFalse
+  )>;
 
 //===----------------------------------------------------------------------===//
 // TanhGrad op patterns.
@@ -331,14 +368,28 @@ def LowerScatterNdOp :
 // Xdivy, Xlog1py and Xlogy op patterns.
 //===----------------------------------------------------------------------===//
 
-class BinaryXopyPat<dag From, dag To>
-    : Pat<From,
-          (TF_SelectV2Op (TF_EqualOp $x,
-                                     (TF_ConstOp:$zero (GetScalarOfType<0> $x)),
-                                     /*incompatible_shape_error*/ConstBoolAttrTrue),
-                         $zero, To)>;
+class BinaryXopyPat<dag From, dag To> : Pat<
+  From,
+  (TF_SelectV2Op
+    (TF_EqualOp
+      $x,
+      (TF_ConstOp:$zero
        (GetScalarOfType<0> $x)
+      ),
+      /*incompatible_shape_error*/ConstBoolAttrTrue
+    ),
+    $zero,
+    To
+  )>;
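The class encodes "yield 0 wherever x == 0, otherwise apply the op". A scalar sketch (illustrative) for the Xdivy instantiation:

#include <cstdio>

double Xdivy(double x, double y) {
  return (x == 0.0) ? 0.0 : x / y;  // select(x == 0, 0, x / y)
}

int main() {
  // Prints: 0 2.5 -- plain 0.0 / 0.0 would be nan instead of 0.
  std::printf("%g %g\n", Xdivy(0.0, 0.0), Xdivy(5.0, 2.0));
}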
 
-foreach fromToPair = [[(TF_XdivyOp $x, $y), (TF_DivOp $x, $y)],
-                      [(TF_Xlog1pyOp $x, $y), (TF_MulOp $x, (TF_Log1pOp $y))],
-                      [(TF_XlogyOp $x, $y), (TF_MulOp $x, (TF_LogOp $y))]] in
-  def : BinaryXopyPat<fromToPair[0], fromToPair[1]>;
+def LowerXdivyOp : BinaryXopyPat<
+  (TF_XdivyOp $x, $y),
+  (TF_DivOp $x, $y)>;
+
+def LowerXlog1pyOp : BinaryXopyPat<
+  (TF_Xlog1pyOp $x, $y),
+  (TF_MulOp $x, (TF_Log1pOp $y))>;
+
+def LowerXlogyOp : BinaryXopyPat<
+  (TF_XlogyOp $x, $y),
+  (TF_MulOp $x, (TF_LogOp $y))>;
@@ -105,9 +105,29 @@ struct CollapseParallelLoopsTo1D
 };
 }  // end anonymous namespace
 
-Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
-                    llvm::ArrayRef<uint32_t> unroll_factors,
-                    bool embed_memref_prints) {
+struct TilingParams {
+  llvm::SmallVector<int64_t, 4> outer_tile;
+  llvm::SmallVector<int64_t, 4> inner_tile;
+};
+
+// We have to anticipate later unrolling in tiling to make sure that we get
+// the requested tiling after unrolling. Compute the new tiling here if
+// needed.
+TilingParams ComputeTilingParas(llvm::ArrayRef<uint32_t> tile_sizes,
+                                llvm::ArrayRef<uint32_t> unroll_factors) {
+  TilingParams params;
+  params.outer_tile.reserve(tile_sizes.size());
+  for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
+    params.outer_tile.push_back(std::get<0>(pair) * std::get<1>(pair));
+    params.inner_tile.push_back(std::get<1>(pair));
+  }
+  params.outer_tile.append(tile_sizes.drop_front(unroll_factors.size()).begin(),
+                           tile_sizes.end());
+  return params;
+}
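A worked example (editorial sketch) of the computation above on plain vectors. tile_sizes = {16, 64} matches the default in kernel_creator.h below; unroll_factors = {4} is a hypothetical choice, not a value from the commit:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<unsigned> tile_sizes = {16, 64};
  const std::vector<unsigned> unroll_factors = {4};
  std::vector<long> outer_tile, inner_tile;
  for (size_t i = 0; i < unroll_factors.size(); ++i) {
    outer_tile.push_back(tile_sizes[i] * unroll_factors[i]);  // 16 * 4 = 64
    inner_tile.push_back(unroll_factors[i]);                  // 4
  }
  for (size_t i = unroll_factors.size(); i < tile_sizes.size(); ++i)
    outer_tile.push_back(tile_sizes[i]);  // untouched tail: 64
  // outer_tile = {64, 64}: tile for the product first so that re-tiling by
  // inner_tile = {4} (the unroll factor) yields the requested 16.
  std::printf("outer = %ld %ld, inner = %ld\n", outer_tile[0], outer_tile[1],
              inner_tile[0]);
}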
+
+Status LowerTFtoLoops(mlir::ModuleOp module,
+                      const TilingParams& tiling_params) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
@@ -154,29 +174,31 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
 
   // Collapse and tile parallel loops.
   pm.addNestedPass<mlir::FuncOp>(std::make_unique<CollapseParallelLoopsTo1D>());
-  // We have to anticipate later unrolling in tiling to make sure that we get
-  // the requested tiling after unrolling. Compute the new tiling here if
-  // needed.
-  llvm::SmallVector<int64_t, 4> tiling_for_unrolling, inner_tile;
-  tiling_for_unrolling.reserve(tile_sizes.size());
-  for (auto pair : llvm::zip(tile_sizes, unroll_factors)) {
-    tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair));
-    inner_tile.push_back(std::get<1>(pair));
-  }
-  tiling_for_unrolling.append(
-      tile_sizes.drop_front(unroll_factors.size()).begin(), tile_sizes.end());
   pm.addNestedPass<mlir::FuncOp>(
-      ::mlir::createParallelLoopTilingPass(tiling_for_unrolling));
-  if (!unroll_factors.empty()) {
+      ::mlir::createParallelLoopTilingPass(tiling_params.outer_tile));
+  if (!tiling_params.inner_tile.empty()) {
     pm.addNestedPass<mlir::FuncOp>(
-        ::mlir::createParallelLoopTilingPass(inner_tile));
+        ::mlir::createParallelLoopTilingPass(tiling_params.inner_tile));
   }
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  if (failed(pm.run(module))) {
+    return InternalError("Lowering TF to loops failed.");
+  }
+  return Status::OK();
+}
 
-  // Greedily map the remaining loop to GPU hardware dimensions.
-  pm.addNestedPass<::mlir::FuncOp>(
-      mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
+Status LowerLoopsToGPUorCPU(mlir::ModuleOp module,
+                            const TilingParams& tiling_params,
+                            bool embed_memref_prints, bool cpu_codegen) {
+  mlir::PassManager pm(module.getContext());
+  applyTensorflowAndCLOptions(pm);
+
+  if (!cpu_codegen) {
+    // Greedily map the remaining loop to GPU hardware dimensions.
+    pm.addNestedPass<::mlir::FuncOp>(
+        mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
+  }
 
   // Now lower the shape computations, bufferize all remaining ops and insert
   // deallocs.
@@ -213,21 +235,25 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
   // Apply the mapping and go to GPU. We cannot do this earlier due to missing
   // interfaces on the GPU dialect.
   // TODO(b/174830459): Move up once implemented.
-  pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
+  if (!cpu_codegen) {
+    pm.addNestedPass<::mlir::FuncOp>(mlir::createParallelLoopToGpuPass());
+  }
 
   // Some basic cleanup.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
   // Make loops with min bounds into a conditional plus static bounds.
   // Only do this if we unrolled in the first place.
-  if (!unroll_factors.empty()) {
+  if (!tiling_params.inner_tile.empty()) {
     pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass());
   }
   // Approximate Tanh using standard operations.
   pm.addNestedPass<::mlir::FuncOp>(
       ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass());
   // Take launches to launches with kernels.
-  pm.addPass(::mlir::createGpuKernelOutliningPass());
+  if (!cpu_codegen) {
+    pm.addPass(::mlir::createGpuKernelOutliningPass());
+  }
 
   pm.addPass(::mlir::createLowerAffinePass());
   // Constraints are removed as late as possible and before lowering to CFG.
@@ -244,6 +270,10 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
   if (failed(pm.run(module))) {
     return InternalError("Lowering to GPU kernels failed.");
   }
   return Status::OK();
 }
 
+Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
+  auto gpu_modules = module.getOps<mlir::gpu::GPUModuleOp>();
+  auto num_modules = std::distance(gpu_modules.begin(), gpu_modules.end());
+  if (num_modules != 1) {
@@ -252,10 +282,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
         << ". Currently we leak memory if there is more than one "
            "module, see https://bugs.llvm.org/show_bug.cgi?id=48385";
   }
-  return Status::OK();
-}
-
-Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module) {
 #if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA)
   return InternalError(
       "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
@@ -337,18 +363,23 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     llvm::ArrayRef<std::string> architectures,
     llvm::ArrayRef<uint32_t> tile_sizes,
     llvm::ArrayRef<uint32_t> unroll_factors, bool embed_memref_prints,
-    bool generate_fatbin, bool print_ptx, bool enable_ftz) {
+    bool generate_fatbin, bool print_ptx, bool enable_ftz, bool cpu_codegen) {
   auto& registry = context.getDialectRegistry();
   mlir::RegisterAllTensorFlowDialects(registry);
   registry.insert<mlir::chlo::HloClientDialect, mlir::mhlo::MhloDialect>();
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
-  TF_RETURN_IF_ERROR(LowerTFtoGPU(module.get(), tile_sizes, unroll_factors,
-                                  embed_memref_prints));
-  TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
-  TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
-  TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
-                                        architectures, generate_fatbin,
-                                        print_ptx, enable_ftz));
+
+  TilingParams tiling_params = ComputeTilingParas(tile_sizes, unroll_factors);
+  TF_RETURN_IF_ERROR(LowerTFtoLoops(module.get(), tiling_params));
+  TF_RETURN_IF_ERROR(LowerLoopsToGPUorCPU(module.get(), tiling_params,
+                                          embed_memref_prints, cpu_codegen));
+  if (!cpu_codegen) {
+    TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr(module.get()));
+    TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
+    TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), kGpuBinaryAttrName,
+                                          architectures, generate_fatbin,
+                                          print_ptx, enable_ftz));
+  }
   TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get()));
   return module;
 }
@@ -33,14 +33,14 @@ limitations under the License.
 namespace tensorflow {
 namespace kernel_gen {
 
-// Converts TF code to LLVM/NVVM. Lowers the host side to LLVM Dialect.
+// Converts TF code to LLVM with or without GPU support.
 xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code,
     llvm::ArrayRef<std::string> architectures = {"sm_75"},
     llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
     llvm::ArrayRef<uint32_t> unroll_factors = {},
     bool embed_memref_prints = false, bool generate_fatbin = true,
-    bool print_ptx = false, bool enable_ftz = false);
+    bool print_ptx = false, bool enable_ftz = false, bool cpu_codegen = false);
 
 }  // namespace kernel_gen
 }  // namespace tensorflow
@@ -106,7 +106,8 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
                 llvm::ArrayRef<std::string> architectures,
                 llvm::ArrayRef<uint32_t> tile_sizes,
                 llvm::ArrayRef<uint32_t> unroll_factors,
-                bool embed_memref_prints, bool print_ptx, bool enable_ftz) {
+                bool embed_memref_prints, bool print_ptx, bool enable_ftz,
+                bool cpu_codegen) {
   // Read TF code.
   std::string tf_code;
   TF_RETURN_IF_ERROR(
@@ -117,7 +118,8 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes,
                               unroll_factors, embed_memref_prints,
-                              /*generate_fatbin=*/true, print_ptx, enable_ftz));
+                              /*generate_fatbin=*/true, print_ptx, enable_ftz,
+                              cpu_codegen));
   // Get binary.
   TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
 
@@ -138,18 +140,21 @@ int main(int argc, char** argv) {
   llvm::cl::opt<std::string> output_file(
       "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
       llvm::cl::init("foo.bin"));
+  llvm::cl::opt<bool> cpu_codegen("cpu_codegen",
+                                  llvm::cl::desc("enable CPU code generation"),
+                                  llvm::cl::init(false));
   llvm::cl::opt<bool> embed_memref_prints(
       "embed_memref_prints",
-      llvm::cl::desc("embeds memref prints at the end of their lifetime"),
+      llvm::cl::desc("embed memref prints at the end of their lifetime"),
      llvm::cl::init(false));
   llvm::cl::opt<bool> print_ptx(
       "print-ptx",
-      llvm::cl::desc("Print generated PTX code per target architecture."),
+      llvm::cl::desc("print generated PTX code per target architecture."),
       llvm::cl::init(false));
   llvm::cl::opt<bool> enable_ftz(
       "enable_ftz",
       llvm::cl::desc(
-          "Enable the denormal flush to zero mode when generating code."),
+          "enable the denormal flush to zero mode when generating code."),
       llvm::cl::init(false));
   llvm::cl::list<std::string> architectures(
       "arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"),
@@ -166,11 +171,11 @@ int main(int argc, char** argv) {
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
   mlir::registerPassManagerCLOptions();
-  llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
+  llvm::cl::ParseCommandLineOptions(argc, argv, "TF op kernel generator\n");
 
   auto status = tensorflow::kernel_gen::Run(
       input_file, output_file, architectures, tile_sizes, unroll_factors,
-      embed_memref_prints, print_ptx, enable_ftz);
+      embed_memref_prints, print_ptx, enable_ftz, cpu_codegen);
   if (!status.ok()) {
     LOG(ERROR) << status;
     return 1;
@ -49,7 +49,9 @@ class GpuKernelToNVVMPass
|
||||
GPUModuleOp m = getOperation();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(m.getContext());
|
||||
mlir::LowerToLLVMOptions llvm_opts;
|
||||
llvm_opts.indexBitwidth = 32;
|
||||
LLVMTypeConverter converter(m.getContext(), llvm_opts);
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateGpuToNVVMConversionPatterns(converter, patterns);
|
||||
populateComplexToLLVMConversionPatterns(converter, patterns);
|
||||
|
@ -1075,7 +1075,7 @@ func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// CHECK-LABEL: func @infeed_dequeue_tuple
|
||||
func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
|
||||
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
|
||||
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = "", layout = [{{\[\[0], \[0]]}}, unit]} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
|
||||
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
|
||||
// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
|
||||
// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
|
||||
// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
|
||||
@ -5149,3 +5149,12 @@ func @replica_id() -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK: func @angle_c64
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
|
||||
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
|
||||
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
|
||||
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
|
||||
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
|
||||
%0 = "tf.Angle"(%arg0): (tensor<complex<f32>>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
@ -70,7 +70,6 @@ namespace mhlo {
|
||||
namespace {
|
||||
|
||||
constexpr char kShardingAttr[] = "mhlo.sharding";
|
||||
constexpr char kLayoutAttr[] = "layout";
|
||||
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
@ -4609,9 +4608,6 @@ class ConvertInfeedDequeueTupleOp
|
||||
rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
|
||||
/*infeed_config=*/rewriter.getStringAttr(""));
|
||||
|
||||
data_and_token->setAttr(kLayoutAttr,
|
||||
GetLayout(data_and_token_type, rewriter));
|
||||
|
||||
if (op._XlaSharding().hasValue()) {
|
||||
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
||||
// proto, so convert to a text form here.
|
||||
@ -6144,7 +6140,7 @@ LogicalResult legalizeTF(
|
||||
PopulateLegalizeTfPatterns(context, &patterns);
|
||||
|
||||
// Add TF->TF lowering patterns.
|
||||
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns);
|
||||
|
||||
// Add TF->HLO legalization patterns via TF2XLA fallback.
|
||||
if (tf2xla_fallback_device_type.hasValue()) {
|
||||
|
@ -595,6 +595,7 @@ foreach Mapping = [
|
||||
[TF_ComplexAbsOp, HLO_AbsOp],
|
||||
[TF_ConjOp, HLOClient_ConjOp],
|
||||
[TF_CosOp, HLO_CosOp],
|
||||
[TF_DigammaOp, HLOClient_DigammaOp],
|
||||
[TF_ExpOp, HLO_ExpOp],
|
||||
[TF_ErfOp, HLOClient_ErfOp],
|
||||
[TF_ErfcOp, HLOClient_ErfcOp],
|
||||
@ -620,6 +621,8 @@ foreach Mapping = [
|
||||
(Mapping[1] $input)>;
|
||||
}
|
||||
|
||||
def : Pat<(TF_AngleOp $x), (HLO_Atan2Op (HLO_ImagOp $x), (HLO_RealOp $x))>;
|
||||
|
||||
// TODO(bixia): Lower Cast with a Complex type source operand or with
|
||||
// Truncate=True for floating point value conversions.
|
||||
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
|
||||
|
@ -184,7 +184,7 @@ class XlaBuilder {
|
||||
|
||||
// Similar to SetOpMetadata, but only set the metadata for the next op.
|
||||
void SetOneShotOpMetadata(OpMetadata metadata) {
|
||||
metadata_ = std::move(metadata);
|
||||
one_shot_metadata_ = std::move(metadata);
|
||||
}
|
||||
|
||||
// Clears the HloMetadata state.
|
||||
|
@ -335,10 +335,15 @@ cc_library(
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":absl_casters",
|
||||
":jax_jit",
|
||||
":py_client",
|
||||
":types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"@pybind11",
|
||||
],
|
||||
|
@ -855,7 +855,7 @@ class CompiledFunction {
|
||||
py::object Call(py::args args, py::kwargs kwargs);
|
||||
|
||||
// This allows `inspect.signature(cpp_jitted_f)` from Python.
|
||||
py::object __signature__() {
|
||||
py::object PythonSignature() {
|
||||
static const auto* inspect = new py::module(py::module::import("inspect"));
|
||||
return inspect->attr("signature")(fun_);
|
||||
}
|
||||
@ -1212,7 +1212,8 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
||||
py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
|
||||
jitlib, "CompiledFunction");
|
||||
cfun.def("__call__", &CompiledFunction::Call);
|
||||
cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
|
||||
cfun.def_property_readonly("__signature__",
|
||||
&CompiledFunction::PythonSignature);
|
||||
|
||||
jitlib.def("set_disable_jit", &SetDisableJit);
|
||||
jitlib.def("get_disable_jit", &GetDisableJit);
|
||||
|
@ -21,11 +21,17 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "pybind11/cast.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/compiler/xla/python/absl_casters.h"
|
||||
#include "tensorflow/compiler/xla/python/jax_jit.h"
|
||||
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace jax {
|
||||
@ -90,6 +96,246 @@ pybind11::tuple CppMeshMappingToPy(
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct PmapCacheEntry {
|
||||
// To get a first version running, we use extensively Python here for the
|
||||
// handling of the arguments and outputs.
|
||||
// TODO(jblespiau): Move more to C++.
|
||||
std::shared_ptr<xla::PyExecutable> executable;
|
||||
// See _cpp_pmap in api.py.
|
||||
py::object backend;
|
||||
// A function taking as argument a list of arguments and returns a list of
|
||||
// list of buffers `[num_devices x num_args]`.
|
||||
py::function handle_args;
|
||||
// A function taking as argument the output of `ExecuteOnLocalDevices` and
|
||||
// returning a list of ShardedDeviceArray objects.
|
||||
py::function out_handler;
|
||||
xla::PyTreeDef out_pytree_def;
|
||||
|
||||
// Ensures a single thread performs the compilation for a given executable.
|
||||
//
|
||||
// The first thread (holding the GIL) will create the CacheEntry associated to
|
||||
// a signature and if the object has been insterted already, other threads
|
||||
// will wait for the notification.
|
||||
absl::Notification compilation_complete;
|
||||
absl::optional<xla::Status> compilation_error = absl::nullopt;
|
||||
|
||||
bool fall_back_to_python = false;
|
||||
};
|
||||
|
||||
} // namespace
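
[Editor's note: the comment above describes a first-thread-compiles, other-threads-wait protocol. A self-contained sketch of that pattern with the same absl::Notification primitive; the names here are illustrative, not from the patch.]

#include <thread>

#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"

struct Entry {  // illustrative stand-in for PmapCacheEntry
  absl::Notification compilation_complete;
  absl::optional<int> executable;  // stand-in for the compiled executable
};

int main() {
  Entry entry;
  // The first thread "compiles", publishes the result, then notifies.
  std::thread compiler([&] {
    entry.executable = 42;
    entry.compilation_complete.Notify();
  });
  // Any other thread blocks until the result has been published.
  std::thread reader([&] {
    if (!entry.compilation_complete.HasBeenNotified()) {
      entry.compilation_complete.WaitForNotification();
    }
    int value = *entry.executable;
    (void)value;
  });
  compiler.join();
  reader.join();
}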

// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class PmapFunction {
public:
PmapFunction(py::function fun, py::function cache_miss,
py::function get_jax_enable_x64, std::vector<int> static_argnums)
: fun_(std::move(fun)),
cache_miss_(std::move(cache_miss)),
static_argnums_(std::move(static_argnums)),
get_jax_enable_x64_(get_jax_enable_x64) {
std::sort(static_argnums_.begin(), static_argnums_.end());
}

~PmapFunction() {
for (const auto& entry : executables_) {
entry.first.DecRef();
}
}

// This function will:
// (a) flatten the inputs using pytree
// (b) get buffer objects from the arguments
// (c) call the executable
// (d) construct `ShardedDeviceArray` objects from the outputs
// (e) reconstruct the `PyTree`.
py::object Call(py::args args, py::kwargs kwargs);

py::object PythonSignature() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(fun_);
}

int cache_size() const { return executables_.size(); }

private:
// Returns nullptr if not present in the cache.
PmapCacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
// Should never return nullptr.
PmapCacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data);

bool always_fallback_to_python_ = false;

const py::function fun_; // The Python function to pmap.
// See JAX _cpp_pmap in api.py for documentation.
const py::function cache_miss_;

// We need to know the static arguments to remove them from the arguments
// passed to the underlying PyExecutable. In sorted order.
std::vector<int> static_argnums_;
// We need a `unique_ptr` here to ensure value pointer stability.
absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
executables_;

const py::function get_jax_enable_x64_;
absl::optional<bool> jax_enable_x64_ = absl::nullopt;

// A vector of size `num_outputs`, specifying the sharding of each output
std::vector<ShardingSpec> sharding_specs_;
};

PmapCacheEntry* PmapFunction::GetCacheEntryIfPresent(
const CallSignature& signature) {
auto found_iterator = executables_.find(signature);
if (found_iterator != executables_.end()) { // Cache hit!
if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
py::gil_scoped_release gil_release;
found_iterator->second->compilation_complete.WaitForNotification();
}
if (found_iterator->second->compilation_error) {
throw std::invalid_argument(
found_iterator->second->compilation_error.value().error_message());
}
return found_iterator->second.get();
}
return nullptr;
}

PmapCacheEntry* PmapFunction::AddCacheEntry(const py::args& args,
const py::kwargs& kwargs,
const CallSignature& signature,
py::object out_and_fastpath_data) {
// We need to insert the element.
auto result =
executables_.emplace(signature, std::make_unique<PmapCacheEntry>());
auto it = result.first;
PmapCacheEntry* cache_entry = it->second.get();
// CallSignatures in the cache own their keyword argument reference.
result.first->first.IncRef();

py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
CHECK_EQ(tuple.size(), 2);
if (tuple[1].is_none()) {
cache_entry->fall_back_to_python = true;
cache_entry->compilation_complete.Notify();
return cache_entry;
}

py::dict pmap_data = py::cast<py::dict>(tuple[1]);
if (py::cast<int>(pmap_data["version"]) != 1) {
throw std::runtime_error(absl::StrCat(
"The versions of jaxlib and Jax are incompatible (pmap cpp version 1 "
"expected, but got ",
py::cast<int>(pmap_data["version"]),
"Upgrade jaxlib and jax. Provided data was:",
py::cast<std::string>(py::str(py::repr(pmap_data)))));
}
// { "version": 1,
// "xla_executable": xla_executable,
// "in_handler": in_handler,
// "out_handler": out_handler,
// "out_pytree_def": out_pytree_def }
auto executable =
py::cast<std::shared_ptr<xla::PyExecutable>>(pmap_data["xla_executable"]);
cache_entry->executable = std::move(executable);
cache_entry->handle_args = py::cast<py::function>(pmap_data["in_handler"]);
cache_entry->out_handler = py::cast<py::function>(pmap_data["out_handler"]);
auto out_tree = py::cast<xla::PyTreeDef>(pmap_data["out_pytree_def"]);
cache_entry->out_pytree_def = std::move(out_tree);

cache_entry->compilation_complete.Notify();
return cache_entry;
}

py::object PmapFunction::Call(py::args args, py::kwargs kwargs) {
if (always_fallback_to_python_) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// Delayed values are retrieved on the first call to `Call`.
if (jax_enable_x64_ == absl::nullopt) {
jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
}

ParsedArgumentsAsBuffers arguments;
if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}

// Get dynamic argument signatures.
for (py::handle arg : arguments.flat_dynamic_args) {
auto signature_or_error = ArgSignatureOfValue(arg, jax_enable_x64_.value());
if (!signature_or_error.ok()) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
arguments.signature.dynamic_args_signatures.push_back(
std::move(signature_or_error).ValueOrDie());
}

// Retrieve/Maybe add the executable to the cache.
PmapCacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
cache_entry = GetCacheEntryIfPresent(arguments.signature);
if (!cache_entry) {
cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
out_and_fastpath_data);
}
CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}
// As we have already computed the results, we can return it.
// It's even *required* e.g. if there are donated arguments, because
// otherwise the buffer which has been donated already will be invalid.
return py::cast<py::tuple>(out_and_fastpath_data)[0];
}

CHECK(cache_entry);
if (cache_entry->fall_back_to_python) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}

// TODO(jblespiau): Use C++ only for this.
py::list arg_list;
for (auto& v : arguments.flat_dynamic_args) {
arg_list.append(v);
}

py::object handled_args = cache_entry->handle_args(arg_list);
py::list list_of_list_of_buffers = py::cast<py::list>(handled_args);

arguments.keep_alive_objects.push_back(
py::cast<py::object>(list_of_list_of_buffers));
// Should be `[num_devices x num_args]`.
std::vector<std::vector<xla::PyBuffer*>> arg_buffers;
arg_buffers.reserve(list_of_list_of_buffers.size());
for (int i = 0; i < list_of_list_of_buffers.size(); ++i) {
std::vector<xla::PyBuffer*> buffers;
buffers.reserve(py::cast<py::list>(list_of_list_of_buffers[i]).size());
for (auto& buf : list_of_list_of_buffers[i]) {
buffers.push_back(py::cast<xla::PyBuffer*>(buf));
}
arg_buffers.push_back(std::move(buffers));
}

std::vector<std::vector<std::unique_ptr<xla::PyBuffer>>> outputs =
ValueOrThrow(cache_entry->executable->ExecuteOnLocalDevices(arg_buffers));

// TODO(jblespiau): Move this to C++.
py::list outputs_as_python_objects;
for (int i = 0; i < outputs.size(); ++i) {
outputs_as_python_objects.append(py::cast(std::move(outputs[i])));
}
py::list flat_sharded_device_arrays =
cache_entry->out_handler(outputs_as_python_objects);
return cache_entry->out_pytree_def.Unflatten(flat_sharded_device_arrays);
}

void BuildPmapSubmodule(pybind11::module& m) {
py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");

@ -167,6 +413,21 @@ void BuildPmapSubmodule(pybind11::module& m) {
&ShardedDeviceArray::GetShardingSpec)
.def_property_readonly("device_buffers",
&ShardedDeviceArray::GetDeviceBuffers);

py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib,
"PmapFunction");
cfun.def("__call__", &PmapFunction::Call);
cfun.def_property_readonly("__signature__", &PmapFunction::PythonSignature);

pmap_lib.def(
"pmap",
[](py::function fun, py::function cache_miss,
py::function get_jax_enable_x64,
std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> {
return std::make_unique<PmapFunction>(
std::move(fun), std::move(cache_miss),
std::move(get_jax_enable_x64), std::move(static_argnums));
});
}

} // namespace jax

@ -70,12 +70,12 @@ struct Chunked {
};

// `Unstacked` means that the dimension is split into chunks of size 1, and
// doesn't appear inside the map. `size` is alwyays the dimension size.
// doesn't appear inside the map. `size` is always the dimension size.
// For example, a Tensor t of shape [N] will be sharded into N tensors of shape
// [], when using `Unstacked(N)`.
struct Unstacked {
public:
explicit Unstacked(int size_) : size(size_) {}
explicit Unstacked(int sz) : size(sz) {}
const int size;

bool operator==(const Unstacked& other) const { return size == other.size; }

@ -121,7 +121,6 @@ pybind11::tuple CppMeshMappingToPy(std::vector<MeshDimAssignment> mesh_mapping);

// Describes how each axis is sharded (if it is), and how it's mapped to the
// devices mesh.
// See `AvalDimSharding` and `MeshDimAssignment`.
class ShardingSpec {
public:
ShardingSpec(std::vector<AvalDimSharding> sharding,

@ -1498,11 +1498,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) {

void HloInstruction::SetupDerivedInstruction(
HloInstruction* derived_instruction) const {
if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
shape_, derived_instruction->shape())) {
// Only copy sharding if the shape of the two instruction is compatible
// because copying it between differently shaped instructions can produce
// invalid shardings.
if (sharding_ != nullptr &&
ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
// Only copy sharding if the tuple tree shape of the two instruction is
// compatible because copying it between differently shaped instructions
// can produce invalid shardings.
derived_instruction->set_sharding(*sharding_);
} else {
derived_instruction->clear_sharding();

@ -749,6 +749,64 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
}

TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Compatible with original shape as tuple tree structure and leaf ranks are
// identical
auto clone_shape = ShapeUtil::MakeShape(F32, {3, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_EQ(tuple_clone->sharding(), sharding);
}

TEST_F(HloInstructionTest,
DoNotPreserveShardingThroughTupleTreeIncompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Incompatible with original shape as tuple tree structure is different
auto clone_shape = ShapeUtil::MakeShape(F32, {2, 2});
clone_shape =
ShapeUtil::MakeTupleShape({clone_shape, clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_FALSE(tuple_clone->has_sharding());
}

TEST_F(HloInstructionTest,
DoNotPreserveShardingThroughLeafRankIncompatibleClone) {
HloSharding sharding = HloSharding::AssignDevice(5);
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
tuple->set_sharding(sharding);
// Incompatible with original shape as leaf ranks are different
auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
EXPECT_FALSE(tuple_clone->has_sharding());
}

TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

@ -141,9 +141,16 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
}
}

if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
if (!ignore_dimensions_) {
if (!ShapeUtil::SameDimensions(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
} else {
if (!ShapeUtil::SameRank(lhs, rhs)) {
VLOG(3) << "CompareShapes: lhs rank != rhs rank";
return false;
}
}

if (!ignore_layout_) {

@ -220,6 +220,10 @@ class Shape {
ignore_dynamic_dimension_ = true;
return *this;
}
Equal& IgnoreDimensions() {
ignore_dimensions_ = true;
return *this;
}

private:
bool ignore_layout_ = false;

@ -229,6 +233,7 @@ class Shape {
bool ignore_element_type_ = false;
bool ignore_fp_precision_ = false;
bool ignore_dynamic_dimension_ = false;
bool ignore_dimensions_ = false;
};

// Test that all fields of the shape are the same, equivalent to Equal().

@ -644,6 +644,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}

/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
CHECK(lhs.IsArray());
CHECK(rhs.IsArray());
return lhs.rank() == rhs.rank();
}

/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}

@ -656,6 +662,15 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
.IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal()
.IgnoreElementType()
.IgnoreLayout()
.IgnoreDimensions()
.IgnoreDynamicDimension()(lhs, rhs);
}
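
[Editor's note: a short illustration of what the new CompatibleKind predicate accepts and rejects, consistent with the shape_util.h comment and the tests above; a sketch, not part of the patch.]

// Same tuple structure, same leaf ranks: element types, dimensions, and
// layouts are ignored, so these compare equal in "kind".
xla::Shape lhs = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::F32, {2, 2}),
     xla::ShapeUtil::MakeShape(xla::F32, {2, 2})});
xla::Shape rhs = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::S32, {3, 3}),
     xla::ShapeUtil::MakeShape(xla::S32, {7, 9})});
bool same_kind = xla::ShapeUtil::CompatibleKind(lhs, rhs);  // true

// A leaf of different rank breaks kind-compatibility.
xla::Shape rank3 = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::F32, {1, 2, 3}),
     xla::ShapeUtil::MakeShape(xla::F32, {2, 2})});
bool different = xla::ShapeUtil::CompatibleKind(lhs, rank3);  // false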

/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal()

@ -247,6 +247,11 @@ class ShapeUtil {
// Precondition: IsArray(lhs) && IsArray(rhs)
static bool SameDimensions(const Shape& lhs, const Shape& rhs);

// Returns whether the LHS and RHS shapes have the same rank; note: does
// not check element type.
// Precondition: IsArray(lhs) && IsArray(rhs)
static bool SameRank(const Shape& lhs, const Shape& rhs);

// Returns whether the lhs and rhs shapes have the same element type.
static bool SameElementType(const Shape& lhs, const Shape& rhs) {
return lhs.element_type() == rhs.element_type();

@ -308,6 +313,11 @@ class ShapeUtil {
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

// Returns true if the tuple tree shapes and leaf ranks are identical.
// Leaf dimensions, element type, and layout are ignored. Tuple elements are
// compared recursively for compatibility.
static bool CompatibleKind(const Shape& lhs, const Shape& rhs);

// As Compatible, but allow one of lhs and rhs to be BF16 while the other
// being F32. Tuple elements are compared recursively for compatibility.
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

@ -27,9 +27,12 @@ REGISTER_COMPLEX(CPU, float, complex64);
REGISTER_COMPLEX(CPU, double, complex128);

#if GOOGLE_CUDA
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER_COMPLEX(GPU, float, complex64);
REGISTER_COMPLEX(GPU, double, complex128);
#endif
#endif

#undef REGISTER_COMPLEX
} // namespace tensorflow

@ -16,10 +16,16 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {

REGISTER3(UnaryOp, CPU, "Digamma", functor::digamma, float, Eigen::half,
double);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "Digamma", functor::digamma, float, Eigen::half,
double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif
#endif

} // namespace tensorflow

@ -20,8 +20,11 @@ limitations under the License.
namespace tensorflow {
namespace functor {
#if GOOGLE_CUDA
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
DEFINE_UNARY2(get_angle, complex64, complex128);
#endif
#endif
} // namespace functor
} // namespace tensorflow

@ -19,8 +19,7 @@ limitations under the License.

namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isfinite, Eigen::half, float, double);
#endif
} // namespace functor

@ -19,8 +19,7 @@ limitations under the License.

namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isinf, Eigen::half, float, double);
#endif
} // namespace functor

@ -19,8 +19,7 @@ limitations under the License.

namespace tensorflow {
namespace functor {
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
DEFINE_UNARY3(isnan, Eigen::half, float, double);
#endif
} // namespace functor

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsFinite", functor::isfinite, float, Eigen::half,
bfloat16, double);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsFinite", functor::isfinite, float, Eigen::half,
double);
#endif

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, bfloat16,
double);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double);
#endif
#endif

@ -20,8 +20,7 @@ REGISTER4(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double,
bfloat16);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
REGISTER3(UnaryOp, GPU, "IsNan", functor::isnan, float, Eigen::half, double);
#endif
#endif

@ -52,6 +52,9 @@ filegroup(
"gpu_op_floor.cc",
"gpu_op_imag.cc",
"gpu_op_invert.cc",
"gpu_op_is_finite.cc",
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_log.cc",
"gpu_op_log1p.cc",
"gpu_op_logical_not.cc",

@ -71,13 +74,12 @@ filegroup(
srcs = [
"gpu_op_acos.cc",
"gpu_op_acosh.cc",
"gpu_op_angle.cc",
"gpu_op_asin.cc",
"gpu_op_asinh.cc",
"gpu_op_digamma.cc",
"gpu_op_exp.cc",
"gpu_op_expm1.cc",
"gpu_op_is_finite.cc",
"gpu_op_is_inf.cc",
"gpu_op_is_nan.cc",
"gpu_op_lgamma.cc",
"gpu_op_sign.cc",
"gpu_op_sin.cc",

@ -97,9 +99,9 @@ filegroup(
)

cc_library(
name = "gpu_ops_base",
srcs = ["gpu_ops_base.cc"],
hdrs = ["gpu_ops_base.h"],
name = "base_op",
srcs = ["base_op.cc"],
hdrs = ["base_op.h"],
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/tools/kernel_gen:tf_framework_c_interface",

@ -113,6 +115,13 @@ cc_library(
],
)

cc_library(
name = "base_gpu_op",
hdrs = ["base_gpu_op.h"],
compatible_with = get_compatible_with_cloud(),
deps = [":base_op"],
)

tf_kernel_library(
name = "cwise_unary_op",
srcs = [":unary_kernel_srcs"],

@ -129,6 +138,7 @@ tf_kernel_library(
":square_kernels",
":acos_kernels",
":acosh_kernels",
":angle_kernels",
":asin_kernels",
":asinh_kernels",
":atan_kernels",

@ -139,6 +149,7 @@ tf_kernel_library(
":conj_kernels",
":cos_kernels",
":cosh_kernels",
":digamma_kernels",
":erf_kernels",
":erfc_kernels",
":exp_kernels",

@ -162,7 +173,7 @@ tf_kernel_library(
":sqrt_kernels",
":tan_kernels",
":tanh_kernels",
":gpu_ops_base",
":base_gpu_op",
"//third_party/eigen3",
],
)

@ -200,13 +211,13 @@ tf_kernel_library(
deps = [
":add_v2_kernels",
":atan2_kernels",
":base_gpu_op",
":bitwise_and_kernels",
":bitwise_or_kernels",
":bitwise_xor_kernels",
":div_kernels",
":equal_kernels",
":floor_div_kernels",
":gpu_ops_base",
":greater_equal_kernels",
":greater_kernels",
":left_shift_kernels",

@ -354,7 +365,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)

gen_kernel_library(

@ -364,7 +374,20 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)

gen_kernel_library(
name = "angle",
output_types = [
"f32",
"f64",
],
tile_size = "256",
types = [
"c64",
"c128",
],
unroll_factors = "2",
)

gen_kernel_library(

@ -374,7 +397,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)

gen_kernel_library(

@ -384,7 +406,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)

gen_kernel_library(

@ -514,7 +535,6 @@ gen_kernel_library(
"f32",
"f64",
],
unroll_factors = "4",
)

[

@ -769,6 +789,7 @@ gen_kernel_library(
)
for name in [
"ceil",
"digamma",
"exp",
"expm1",
"floor",

tensorflow/core/kernels/mlir_generated/base_gpu_op.h (new file, 77 lines)
@ -0,0 +1,77 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_

#include "tensorflow/core/kernels/mlir_generated/base_op.h"

namespace tensorflow {

#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
data_type)

#define GENERATE_UNARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_UNARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)

#define GENERATE_UNARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
GENERATE_UNARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type)

#define REGISTER_ALIASED_GPU_KERNEL(tf_op, mlir_op, mlir_type, \
mlir_output_type, data_type) \
REGISTER_ALIASED_KERNEL(tf_op, mlir_op, GPU, mlir_type, mlir_output_type, \
data_type)

#define REGISTER_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
REGISTER_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type)

#define REGISTER_COMPLEX_GPU_KERNEL(tf_op, mlir_type, mlir_output_type, \
data_type, input_data_type) \
REGISTER_COMPLEX_KERNEL(tf_op, GPU, mlir_type, mlir_output_type, data_type, \
input_data_type)

#define REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, \
mlir_output_type) \
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, GPU, mlir_type, mlir_output_type)

#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(tf_op, mlir_type, \
tf_data_type, data_type) \
GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, \
data_type)

#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2( \
tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
input_data_type) \
GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, GPU, mlir_type, \
mlir_output_type, tf_data_type, \
result_data_type, input_data_type)

#define GENERATE_BINARY_GPU_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_BINARY_KERNEL(tf_op, GPU, mlir_type, tf_data_type, data_type)

#define GENERATE_BINARY_GPU_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, GPU, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type)

} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_op.h"

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

@ -97,9 +97,9 @@ Tensor ConvertDescriptorToTensor(

template <DataType TfDataType, typename OutputDataType, typename Kernel,
typename InputDataType = OutputDataType>
class MlirUnrankedOp : public OpKernel {
class MlirOp : public OpKernel {
public:
explicit MlirUnrankedOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit MlirOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
llvm::SmallVector<::UnrankedMemRefType<InputDataType>, 2> input_descs;

@ -142,114 +142,118 @@ class MlirUnrankedOp : public OpKernel {
}
};

#define MLIR_FUNCTION(tf_op, mlir_type, mlir_output_type) \
_mlir_ciface_##tf_op##_##mlir_type##_##mlir_output_type
#define MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type) \
_mlir_ciface_##tf_op##_##platform##_##mlir_type##_##mlir_output_type

#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, mlir_type, mlir_output_type, \
data_type) \
#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, platform, mlir_type, \
mlir_output_type, data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU).TypeConstraint<data_type>("T"), \
MlirUnranked##mlir_op##mlir_type##mlir_output_type##Op);
Name(#tf_op).Device(DEVICE_##platform).TypeConstraint<data_type>("T"), \
Mlir##mlir_op##platform##mlir_type##mlir_output_type##Op);

#define REGISTER_KERNEL(tf_op, mlir_type, mlir_output_type, data_type) \
REGISTER_ALIASED_KERNEL(tf_op, tf_op, mlir_type, mlir_output_type, data_type)
#define REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
data_type) \
REGISTER_ALIASED_KERNEL(tf_op, tf_op, platform, mlir_type, mlir_output_type, \
data_type)

#define REGISTER_COMPLEX_KERNEL(tf_op, mlir_type, mlir_output_type, data_type, \
input_data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op) \
.Device(DEVICE_GPU) \
.TypeConstraint<input_data_type>("T") \
.TypeConstraint<data_type>("Tout"), \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op);
#define REGISTER_COMPLEX_KERNEL(tf_op, platform, mlir_type, mlir_output_type, \
data_type, input_data_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op) \
.Device(DEVICE_##platform) \
.TypeConstraint<input_data_type>("T") \
.TypeConstraint<data_type>("Tout"), \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);

#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, mlir_type, mlir_output_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_GPU), \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op);
#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, platform, mlir_type, \
mlir_output_type) \
REGISTER_KERNEL_BUILDER( \
Name(#tf_op).Device(DEVICE_##platform), \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op);

// OpKernel with Compute function that converts input tensors to unranked
// memref descriptors and calls mlir-generated unranked kernel. The outputs
// are converted back to tensors using MlirTensorBuffer to take ownership of
// pre-allocated memory.
#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_type, data_type)
#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, platform, mlir_type, \
tf_data_type, data_type) \
GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)

#define GENERATE_AND_REGISTER_BINARY_KERNEL2( \
tf_op, mlir_type, mlir_output_type, tf_data_type, result_data_type, \
input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_output_type, tf_data_type, \
result_data_type, input_data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_output_type, input_data_type)
#define GENERATE_AND_REGISTER_BINARY_KERNEL2( \
tf_op, platform, mlir_type, mlir_output_type, tf_data_type, \
result_data_type, input_data_type) \
GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, input_data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_output_type, input_data_type)

#define GENERATE_BINARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_type, tf_data_type, \
#define GENERATE_BINARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
data_type) \
GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
data_type, data_type)

#define GENERATE_BINARY_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type, \
mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg1, \
const ::UnrankedMemRefType<input_data_type>* arg2); \
\
namespace { \
class MlirUnranked##tf_op##mlir_type##mlir_output_type##Op \
: public MlirUnrankedOp< \
tf_data_type, result_data_type, \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirUnrankedOp::MlirUnrankedOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>(MLIR_FUNCTION( \
tf_op, mlir_type, mlir_output_type)(ctx, &args[0], &args[1])); \
} \
}; \
#define GENERATE_BINARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg1, \
const ::UnrankedMemRefType<input_data_type>* arg2); \
\
namespace { \
class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
: public MlirOp<tf_data_type, result_data_type, \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirOp::MlirOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>( \
MLIR_FUNCTION(tf_op, platform, mlir_type, mlir_output_type)( \
ctx, &args[0], &args[1])); \
} \
}; \
}

#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, \
data_type) \
GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, mlir_type, mlir_type, data_type)
#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, mlir_type, \
tf_data_type, data_type) \
GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, data_type) \
REGISTER_KERNEL(tf_op, platform, mlir_type, mlir_type, data_type)

#define GENERATE_UNARY_KERNEL(tf_op, mlir_type, tf_data_type, data_type) \
GENERATE_UNARY_KERNEL2(tf_op, mlir_type, mlir_type, tf_data_type, data_type, \
data_type)
#define GENERATE_UNARY_KERNEL(tf_op, platform, mlir_type, tf_data_type, \
data_type) \
GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_type, tf_data_type, \
data_type, data_type)

#define GENERATE_UNARY_KERNEL2(tf_op, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type, \
mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg); \
\
namespace { \
class MlirUnranked##tf_op##mlir_type##mlir_output_type##Op \
: public MlirUnrankedOp< \
tf_data_type, result_data_type, \
MlirUnranked##tf_op##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirUnrankedOp::MlirUnrankedOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>( \
MLIR_FUNCTION(tf_op, mlir_type, mlir_output_type)(ctx, &args[0])); \
} \
}; \
#define GENERATE_UNARY_KERNEL2(tf_op, platform, mlir_type, mlir_output_type, \
tf_data_type, result_data_type, \
input_data_type) \
extern "C" UntypedUnrankedMemRefType MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)( \
tensorflow::OpKernelContext * ctx, \
const ::UnrankedMemRefType<input_data_type>* arg); \
\
namespace { \
class Mlir##tf_op##platform##mlir_type##mlir_output_type##Op \
: public MlirOp<tf_data_type, result_data_type, \
Mlir##tf_op##platform##mlir_type##mlir_output_type##Op, \
input_data_type> { \
public: \
using MlirOp::MlirOp; \
\
static ::UnrankedMemRefType<result_data_type> Invoke( \
OpKernelContext* ctx, \
llvm::ArrayRef<::UnrankedMemRefType<input_data_type>> args) { \
return ConvertToTyped<result_data_type>(MLIR_FUNCTION( \
tf_op, platform, mlir_type, mlir_output_type)(ctx, &args[0])); \
} \
}; \
}

} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_BASE_H_
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_
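
[Editor's note: to see what the platform-parameterized macros produce, here is a hand-expanded sketch of one GPU instantiation, GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f16, DT_HALF, Eigen::half); formatting and exact expansion are approximate.]

// The extern "C" declaration now carries the platform in the symbol name:
extern "C" UntypedUnrankedMemRefType _mlir_ciface_Abs_GPU_f16_f16(
    tensorflow::OpKernelContext* ctx,
    const ::UnrankedMemRefType<Eigen::half>* arg);

namespace {
class MlirAbsGPUf16f16Op
    : public MlirOp<DT_HALF, Eigen::half, MlirAbsGPUf16f16Op, Eigen::half> {
 public:
  using MlirOp::MlirOp;

  static ::UnrankedMemRefType<Eigen::half> Invoke(
      OpKernelContext* ctx,
      llvm::ArrayRef<::UnrankedMemRefType<Eigen::half>> args) {
    return ConvertToTyped<Eigen::half>(
        _mlir_ciface_Abs_GPU_f16_f16(ctx, &args[0]));
  }
};
}  // namespace

REGISTER_KERNEL_BUILDER(
    Name("Abs").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    MlirAbsGPUf16f16Op);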

@ -49,9 +49,11 @@ def _gen_mlir_op_impl(ctx):
inputs = [ctx.file.template],
outputs = [ctx.outputs.out],
command = (
(("cat %s | sed 's/_elem_type/_%s/g' | sed 's/elem_type/%s/g' | " +
"sed 's/_output_type/_%s/g' | sed 's/output_type/%s/g' > %s")) % (
(("cat %s | sed 's/platform/%s/g' | sed 's/_elem_type/_%s/g' | " +
"sed 's/elem_type/%s/g' | " + "sed 's/_output_type/_%s/g' | " +
"sed 's/output_type/%s/g' > %s")) % (
ctx.file.template.path,
ctx.attr.platform.upper(),
ctx.attr.type,
mlir_type,
ctx.attr.output_type,

@ -68,22 +70,26 @@ _gen_mlir_op_rule = rule(
"template": attr.label(mandatory = True, allow_single_file = True),
"type": attr.string(mandatory = True),
"output_type": attr.string(mandatory = True),
"platform": attr.string(mandatory = True),
"out": attr.output(mandatory = True),
},
)

def _gen_mlir_op(name, type, output_type):
def _gen_mlir_op(name, type, platform, output_type):
_gen_mlir_op_rule(
name = "generate_{name}_{type}_{output_type}_mlir".format(
name = "generate_{name}_{platform}_{type}_{output_type}_mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
template = "op_definitions/{name}.mlir.tmpl".format(name = name),
platform = platform,
type = type,
output_type = output_type,
out = "{name}_{type}_{output_type}.mlir".format(
out = "{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),

@ -177,6 +183,7 @@ def gen_kernel_library(
tile_size,
output_types = None,
tags = [],
platform = "gpu",
unroll_factors = None,
extra_args = []):
""" Generate a library with kernels for a specific tensorflow op.

@ -190,6 +197,7 @@ def gen_kernel_library(
entry in output_types. By default, output_types = types is
assumed.
tags: The tags which should be added to the library.
platform: Platform for which to compile, i.e. "cpu" or "gpu"
unroll_factors: The unrolling specification, e.g. "4,4"
extra_args: Extra arguments to pass to the generator tool.
"""

@ -200,17 +208,20 @@ def gen_kernel_library(
for (type, output_type) in zip(types, output_types):
_gen_mlir_op(
name = name,
platform = platform,
type = type,
output_type = output_type,
)
_gen_kernel_fatbin_rule(
name = "{name}_{type}_{output_type}_kernel_generator".format(
name = "{name}_{platform}_{type}_{output_type}_kernel_generator".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
mlir_op = "{name}_{type}_{output_type}.mlir".format(
mlir_op = "{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),

@ -223,8 +234,9 @@ def gen_kernel_library(

# We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.
native.sh_test(
name = "{name}_{type}_{output_type}_gen_test".format(
name = "{name}_{platform}_{type}_{output_type}_gen_test".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),

@ -232,16 +244,18 @@ def gen_kernel_library(
tags = ["no_rocm"],
args = [
"$(location //tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel)",
"$(location {name}_{type}_{output_type}.mlir)".format(
"$(location {name}_{platform}_{type}_{output_type}.mlir)".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),
],
size = "medium",
data = [
":{name}_{type}_{output_type}.mlir".format(
":{name}_{platform}_{type}_{output_type}.mlir".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
),

@ -252,8 +266,9 @@ def gen_kernel_library(
native.cc_library(
name = name + "_kernels",
compatible_with = get_compatible_with_cloud(),
deps = if_gpu_is_configured([":{name}_{type}_{output_type}_kernel_generator".format(
deps = if_gpu_is_configured([":{name}_{platform}_{type}_{output_type}_kernel_generator".format(
name = name,
platform = platform,
type = type,
output_type = output_type,
) for (type, output_type) in zip(types, output_types)]),
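
[Editor's note: as a concrete example of the new naming scheme, inferred from the format strings above: with name = "angle", platform = "gpu", type = "c64", and output_type = "f32" (the c64 entry of the angle library added earlier in this change), the macros generate targets such as generate_angle_gpu_c64_f32_mlir, the file angle_gpu_c64_f32.mlir, the rule angle_gpu_c64_f32_kernel_generator, and the test angle_gpu_c64_f32_gen_test.]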
|
||||
|
@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, f64, DT_DOUBLE, double);
// TODO(b/25387198): Add an int32 kernel.
GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Abs, i64, DT_INT64, int64);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Acos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Acos, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acos, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Acosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Acosh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Acosh, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(AddV2, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, i64, DT_INT64, int64);

} // namespace tensorflow

tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc (new file, 30 lines)
@ -0,0 +1,30 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_GPU_KERNEL2(Angle, c64, f32, DT_FLOAT, float,
                           std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Angle, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Angle, c128, f64, DT_DOUBLE, double,
                           std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Angle, c128, f64, double, std::complex<double>);

} // namespace tensorflow
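
Note: Angle maps each complex element to its phase, matching std::arg. A minimal host-side sketch of the elementwise semantics follows; it is illustrative only, the helper name is hypothetical, and the registered kernels above are generated from MLIR rather than from code like this.

#include <complex>
#include <vector>

// Hypothetical reference for the elementwise semantics of Angle: the phase
// of each complex element, i.e. atan2(imag, real). Illustrative only.
template <typename T>
std::vector<T> AngleReference(const std::vector<std::complex<T>>& in) {
  std::vector<T> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) out[i] = std::arg(in[i]);
  return out;
}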

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Asin, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asin, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Asinh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Asinh, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Atan, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Atan, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atan, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Atan2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Atan2, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Atanh, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, i64, DT_INT64, int64);

// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseAnd, ui64, DT_UINT64, uint64);

} // namespace tensorflow

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, i64, DT_INT64, int64);

// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseOr, ui64, DT_UINT64, uint64);

} // namespace tensorflow

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, i64, DT_INT64, int64);

// TODO(b/172804967): Enable once fixed.
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui8, DT_UINT8, uint8);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui16, DT_UINT16, uint16);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui32, DT_UINT32, uint32);
// GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(BitwiseXor, ui64, DT_UINT64, uint64);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Ceil, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Ceil, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -16,15 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_BINARY_KERNEL2(Complex, f32, c64, DT_COMPLEX64, std::complex<float>,
                        float);
REGISTER_COMPLEX_KERNEL(Complex, f32, c64, std::complex<float>, float);
GENERATE_BINARY_KERNEL2(Complex, f64, c128, DT_COMPLEX128, std::complex<double>,
                        double);
REGISTER_COMPLEX_KERNEL(Complex, f64, c128, std::complex<double>, double);
GENERATE_BINARY_GPU_KERNEL2(Complex, f32, c64, DT_COMPLEX64,
                            std::complex<float>, float);
REGISTER_COMPLEX_GPU_KERNEL(Complex, f32, c64, std::complex<float>, float);
GENERATE_BINARY_GPU_KERNEL2(Complex, f64, c128, DT_COMPLEX128,
                            std::complex<double>, double);
REGISTER_COMPLEX_GPU_KERNEL(Complex, f64, c128, std::complex<double>, double);

} // namespace tensorflow

@ -16,15 +16,16 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(ComplexAbs, c64, f32, DT_FLOAT, float,
                       std::complex<float>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(ComplexAbs, c128, f64, DT_DOUBLE, double,
                       std::complex<double>);
REGISTER_COMPLEX_KERNEL(ComplexAbs, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c64, f32, DT_FLOAT, float,
                           std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(ComplexAbs, c128, f64, DT_DOUBLE, double,
                           std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(ComplexAbs, c128, f64, double,
                            std::complex<double>);

} // namespace tensorflow

@ -16,13 +16,13 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, c64, DT_COMPLEX64,
                                   std::complex<float>);
GENERATE_AND_REGISTER_UNARY_KERNEL(Conj, c128, DT_COMPLEX128,
                                   std::complex<double>);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c64, DT_COMPLEX64,
                                       std::complex<float>);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Conj, c128, DT_COMPLEX128,
                                       std::complex<double>);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cos, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cos, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Cosh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Cosh, f64, DT_DOUBLE, double);

} // namespace tensorflow

tensorflow/core/kernels/mlir_generated/gpu_op_digamma.cc (new file, 25 lines)
@ -0,0 +1,25 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Digamma, f64, DT_DOUBLE, double);

} // namespace tensorflow
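
Note: Digamma is the derivative of lgamma, which the f16/f32/f64 kernels above compute elementwise. A rough numerical cross-check follows, using only <cmath>; the helper name is hypothetical, the central difference is only an approximation, and the generated kernel is assumed to use a proper series/asymptotic implementation instead.

#include <cmath>
#include <cstdio>

// Hypothetical sanity check: digamma(x) = d/dx ln(Gamma(x)), approximated
// with a central difference over std::lgamma. Illustrative only.
double DigammaApprox(double x) {
  const double h = 1e-5;
  return (std::lgamma(x + h) - std::lgamma(x - h)) / (2.0 * h);
}

int main() {
  // digamma(1) == -gamma (Euler-Mascheroni constant), approx -0.5772157.
  std::printf("%.7f\n", DigammaApprox(1.0));
  return 0;
}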

@ -14,21 +14,21 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Div, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Div, i64, DT_INT64, int64);

REGISTER_ALIASED_KERNEL(RealDiv, Div, f16, f16, Eigen::half)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f32, f32, float)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f64, f64, double)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f16, f16, Eigen::half)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f32, f32, float)
REGISTER_ALIASED_GPU_KERNEL(RealDiv, Div, f64, f64, double)

REGISTER_ALIASED_KERNEL(TruncateDiv, Div, i16, i16, int16)
REGISTER_ALIASED_KERNEL(TruncateDiv, Div, i64, i64, int64)
REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i16, i16, int16)
REGISTER_ALIASED_GPU_KERNEL(TruncateDiv, Div, i64, i64, int64)

} // namespace tensorflow

@ -16,18 +16,18 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f16, i1, DT_BOOL, bool,
                                     Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i16, i1, DT_BOOL, bool, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Equal, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Equal, i64, i1, DT_BOOL, bool, int64);

} // namespace tensorflow
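
Note: the recurring TODO(b/25387198) reflects a TensorFlow convention: int32 tensors are kept in host memory on GPU devices, so an ordinary device kernel cannot be registered for them, and this change simply omits int32 variants. A hedged sketch of what a host-memory int32 registration typically looks like; EqualInt32Op is hypothetical, broadcasting is ignored, and the generated kernels in this commit do not take this path.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel: inputs stay in host memory despite GPU placement.
class EqualInt32Op : public OpKernel {
 public:
  using OpKernel::OpKernel;
  void Compute(OpKernelContext* ctx) override {
    const Tensor& x = ctx->input(0);
    const Tensor& y = ctx->input(1);
    Tensor* z = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &z));
    const auto xf = x.flat<int32>();
    const auto yf = y.flat<int32>();
    auto zf = z->flat<bool>();
    // Plain elementwise comparison on the host (no broadcasting).
    for (int64 i = 0; i < xf.size(); ++i) zf(i) = (xf(i) == yf(i));
  }
};

// HostMemory pins the int32 operands and the bool result to host memory.
REGISTER_KERNEL_BUILDER(Name("Equal")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("x")
                            .HostMemory("y")
                            .HostMemory("z"),
                        EqualInt32Op);

}  // namespace tensorflow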

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erf, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erf, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Erfc, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Erfc, f16, DT_HALF, Eigen::half);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Exp, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Exp, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Expm1, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Expm1, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Floor, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Floor, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(FloorDiv, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(FloorDiv, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -16,17 +16,21 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f16, i1, DT_BOOL, bool,
                                     Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f32, i1, DT_BOOL, bool,
                                         float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, f64, i1, DT_BOOL, bool,
                                         double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i16, i1, DT_BOOL, bool,
                                         int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Greater, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Greater, i64, i1, DT_BOOL, bool,
                                         int64);

} // namespace tensorflow

@ -16,21 +16,22 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f16, i1, DT_BOOL, bool,
                                     Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f32, i1, DT_BOOL, bool,
                                     float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, f64, i1, DT_BOOL, bool,
                                     double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i16, i1, DT_BOOL, bool,
                                     int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f32, i1, DT_BOOL, bool,
                                         float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, f64, i1, DT_BOOL, bool,
                                         double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i8, i1, DT_BOOL, bool,
                                         int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i16, i1, DT_BOOL, bool,
                                         int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(GreaterEqual, i64, i1, DT_BOOL, bool,
                                     int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(GreaterEqual, i64, i1, DT_BOOL, bool,
                                         int64);

} // namespace tensorflow

@ -16,14 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(Imag, c64, f32, DT_FLOAT, float, std::complex<float>);
REGISTER_COMPLEX_KERNEL(Imag, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(Imag, c128, f64, DT_DOUBLE, double,
                       std::complex<double>);
REGISTER_COMPLEX_KERNEL(Imag, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(Imag, c64, f32, DT_FLOAT, float,
                           std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Imag, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Imag, c128, f64, DT_DOUBLE, double,
                           std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Imag, c128, f64, double, std::complex<double>);

} // namespace tensorflow

@ -16,13 +16,13 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_UNARY_KERNEL(Invert, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Invert, i64, DT_INT64, int64);

} // namespace tensorflow

@ -16,15 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(IsFinite, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsFinite, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsFinite, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsFinite, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsFinite, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsFinite, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsFinite, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsFinite, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsFinite, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsFinite, f64, i1, double);

} // namespace tensorflow

@ -16,15 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(IsInf, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsInf, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsInf, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsInf, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsInf, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsInf, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsInf, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsInf, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsInf, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsInf, f64, i1, double);

} // namespace tensorflow

@ -16,15 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(IsNan, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_KERNEL(IsNan, f16, i1, Eigen::half);
GENERATE_UNARY_KERNEL2(IsNan, f32, i1, DT_BOOL, bool, float);
REGISTER_KERNEL(IsNan, f32, i1, float);
GENERATE_UNARY_KERNEL2(IsNan, f64, i1, DT_BOOL, bool, double);
REGISTER_KERNEL(IsNan, f64, i1, double);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f16, i1, DT_BOOL, bool, Eigen::half);
REGISTER_GPU_KERNEL(IsNan, f16, i1, Eigen::half);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f32, i1, DT_BOOL, bool, float);
REGISTER_GPU_KERNEL(IsNan, f32, i1, float);
GENERATE_UNARY_GPU_KERNEL2(IsNan, f64, i1, DT_BOOL, bool, double);
REGISTER_GPU_KERNEL(IsNan, f64, i1, double);

} // namespace tensorflow

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(LeftShift, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(LeftShift, i64, DT_INT64, int64);

} // namespace tensorflow

@ -16,16 +16,17 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f16, i1, DT_BOOL, bool, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i16, i1, DT_BOOL, bool, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(Less, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(Less, i64, i1, DT_BOOL, bool, int64);

} // namespace tensorflow

@ -16,17 +16,22 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f16, i1, DT_BOOL, bool,
                                     Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f32, i1, DT_BOOL, bool,
                                         float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, f64, i1, DT_BOOL, bool,
                                         double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i8, i1, DT_BOOL, bool,
                                         int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i16, i1, DT_BOOL, bool,
                                         int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(LessEqual, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(LessEqual, i64, i1, DT_BOOL, bool,
                                         int64);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Lgamma, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Lgamma, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Log1p, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Log1p, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_BINARY_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
GENERATE_BINARY_GPU_KERNEL(LogicalAnd, i1, DT_BOOL, bool);
// LogicalAnd does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalAnd, i1, i1);

} // namespace tensorflow
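
Note: the NO_TYPE_CONSTRAINT macros exist because registrations normally constrain the op's "T" attribute, while LogicalAnd's OpDef has no "T" attribute at all, so adding TypeConstraint<bool>("T") would make the registration invalid. A hedged sketch of the two registration shapes, with hypothetical kernel class names; the macros above presumably expand to something along these lines, though not necessarily this exact code.

#include "tensorflow/core/framework/op_kernel.h"

// Op with a "T" attribute: constrain the registration to one dtype.
// REGISTER_KERNEL_BUILDER(
//     Name("AddV2").Device(DEVICE_GPU).TypeConstraint<float>("T"),
//     GeneratedAddV2Op);  // hypothetical class

// LogicalAnd has no "T" attribute, so the constraint must be omitted.
// REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_GPU),
//                         GeneratedLogicalAndOp);  // hypothetical class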

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL(LogicalNot, i1, DT_BOOL, bool);
GENERATE_UNARY_GPU_KERNEL(LogicalNot, i1, DT_BOOL, bool);
// LogicalNot does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalNot, i1, i1);

} // namespace tensorflow

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_BINARY_KERNEL(LogicalOr, i1, DT_BOOL, bool);
GENERATE_BINARY_GPU_KERNEL(LogicalOr, i1, DT_BOOL, bool);
// LogicalOr does not have a "T" attribute because it only works with type
// bool. So we need to register it without TypeConstraint<bool>("T").
REGISTER_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1, i1);
REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(LogicalOr, i1, i1);

} // namespace tensorflow

@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Maximum, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Maximum, i64, DT_INT64, int64);

} // namespace tensorflow

@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Minimum, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Minimum, i64, DT_INT64, int64);

} // namespace tensorflow

@ -14,16 +14,16 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i8, DT_INT8, int8);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(Mul, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Mul, i64, DT_INT64, int64);

} // namespace tensorflow

@ -14,16 +14,16 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i16, DT_INT16, int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_UNARY_KERNEL(Neg, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Neg, i64, DT_INT64, int64);

} // namespace tensorflow

@ -16,18 +16,22 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f16, i1, DT_BOOL, bool,
                                     Eigen::half);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f32, i1, DT_BOOL, bool, float);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, f64, i1, DT_BOOL, bool, double);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i16, i1, DT_BOOL, bool, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f16, i1, DT_BOOL, bool,
                                         Eigen::half);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f32, i1, DT_BOOL, bool,
                                         float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, f64, i1, DT_BOOL, bool,
                                         double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i1, i1, DT_BOOL, bool, bool);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i8, i1, DT_BOOL, bool, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i16, i1, DT_BOOL, bool,
                                         int16);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_BINARY_KERNEL2(NotEqual, i64, i1, DT_BOOL, bool, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(NotEqual, i64, i1, DT_BOOL, bool,
                                         int64);

} // namespace tensorflow

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f16, DT_HALF, Eigen::half);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, f64, DT_DOUBLE, double);
|
||||
GENERATE_AND_REGISTER_BINARY_KERNEL(Pow, i64, DT_INT64, int64);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f16, DT_HALF, Eigen::half);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, f64, DT_DOUBLE, double);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Pow, i64, DT_INT64, int64);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,14 +16,15 @@ limitations under the License.
#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_UNARY_KERNEL2(Real, c64, f32, DT_FLOAT, float, std::complex<float>);
REGISTER_COMPLEX_KERNEL(Real, c64, f32, float, std::complex<float>);
GENERATE_UNARY_KERNEL2(Real, c128, f64, DT_DOUBLE, double,
                       std::complex<double>);
REGISTER_COMPLEX_KERNEL(Real, c128, f64, double, std::complex<double>);
GENERATE_UNARY_GPU_KERNEL2(Real, c64, f32, DT_FLOAT, float,
                           std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Real, c64, f32, float, std::complex<float>);
GENERATE_UNARY_GPU_KERNEL2(Real, c128, f64, DT_DOUBLE, double,
                           std::complex<double>);
REGISTER_COMPLEX_GPU_KERNEL(Real, c128, f64, double, std::complex<double>);

} // namespace tensorflow

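Real is the one op in this batch that splits generation from registration: GENERATE_UNARY_GPU_KERNEL2 emits the kernel for a complex-input/real-output type pair, and REGISTER_COMPLEX_GPU_KERNEL registers it separately, where the fused GENERATE_AND_REGISTER_* forms do both at once. The companion Imag op has the same complex-to-real shape, so it would presumably use the identical pattern; the lines below are hypothetical, not part of this diff:

// Hypothetical illustration only; Imag is not touched by this commit.
GENERATE_UNARY_GPU_KERNEL2(Imag, c64, f32, DT_FLOAT, float,
                           std::complex<float>);
REGISTER_COMPLEX_GPU_KERNEL(Imag, c64, f32, float, std::complex<float>);
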
@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_KERNEL(RightShift, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i8, DT_INT8, int8);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i16, DT_INT16, int16);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i32, DT_INT32, int32);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(RightShift, i64, DT_INT64, int64);

} // namespace tensorflow

@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Rsqrt, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Rsqrt, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, f64, DT_DOUBLE, double);
// TODO(b/25387198): We cannot use a regular GPU kernel for int32.
GENERATE_AND_REGISTER_UNARY_KERNEL(Sign, i64, DT_INT64, int64);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sign, i64, DT_INT64, int64);
// TODO(b/162577610): Register the kernel for complex types and bfloat.

} // namespace tensorflow

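The recurring TODO(b/25387198) deserves a gloss: TensorFlow conventionally keeps int32 tensors in host memory even when an op is placed on the GPU, because small int32 values typically feed shape computations, so a regular device kernel would address the wrong memory space. A host-memory registration along the following lines would be needed instead; this is a hedged sketch, and the kernel class name is invented:

// Assumption, not code from this commit: pin the int32 input/output to
// host memory so the kernel reads them where TensorFlow actually placed
// them. HypotheticalSignInt32Op is a placeholder class name.
REGISTER_KERNEL_BUILDER(Name("Sign")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        HypotheticalSignInt32Op);
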
@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sin, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f16, DT_HALF, Eigen::half);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sin, f64, DT_DOUBLE, double);

} // namespace tensorflow

@ -14,11 +14,11 @@ limitations under the License.
==============================================================================*/

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
#include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h"

namespace tensorflow {

GENERATE_AND_REGISTER_UNARY_KERNEL(Sinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_KERNEL(Sinh, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(Sinh, f64, DT_DOUBLE, double);

} // namespace tensorflow

Some files were not shown because too many files have changed in this diff.