Lower TensorFlow Einsum op to HLO
Specifically, * Verify TF Einsum op arity * Define Einsum and UnaryEinsum op in HLO. UnaryEinsum is defined so that the op is not variadic or requires conversion from unary to binary at the time of import. This way translations are simplified and also it is easier to operand on the op. * Convert TF Einsum op to HLO Einsum op or UnaryEinsum op. * Canonicalize UnaryEinsum to Einsum op with two Also, added support for the StringAttr in HLO exporter generator. PiperOrigin-RevId: 283441951 Change-Id: I0d86ae2723209b08eb2c243a9dcde843dbd5897c
This commit is contained in:
parent
a46fa0b405
commit
513f16d55d
@ -1358,6 +1358,10 @@ Comparison with `numpy.einsum`:
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
|
@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<DivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EinsumOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that,
|
||||
// * Arity of the op is at most two.
|
||||
//
|
||||
// TODO(hinsu): Verify einsum equation attribute.
|
||||
static LogicalResult Verify(EinsumOp op) {
|
||||
if (op.N() > 2) {
|
||||
return op.emitOpError("supports at most two operands");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EmptyTensorListOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1650,3 +1650,11 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
|
||||
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
|
||||
// expected-error @+1 {{supports at most two operands}}
|
||||
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
@ -404,6 +404,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
|
@ -47,10 +47,10 @@ limitations under the License.
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
|
||||
|
||||
namespace xla_hlo {
|
||||
|
||||
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
|
||||
@ -936,6 +936,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
|
||||
build(builder, result, builder->getTupleType(types), values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnaryEinsumOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void UnaryEinsumOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<UnaryEinsumToEinsum>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompareOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -730,6 +730,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
def BASE_EinsumOp {
|
||||
string summary = "Einsum operator";
|
||||
|
||||
string description = [{
|
||||
Returns a tensor whose elements are defined by equation, which is written
|
||||
in a shorthand form inspired by the Einstein summation convention.
|
||||
}];
|
||||
}
|
||||
|
||||
def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
StrAttr:$einsum_config
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
// TODO(hinsu): Canonicalize to lower this client side HLO op to server
|
||||
// side HLO ops.
|
||||
}
|
||||
|
||||
def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
StrAttr:$einsum_config
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
// UnarayEinsumOp is unconditionally canonicalized to the binary EinsumOp so
|
||||
// the HLO converter shouldn't be invoked.
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace xla {
|
||||
|
||||
@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
|
||||
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||
}
|
||||
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
|
||||
DenseElementsAttr attr;
|
||||
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
return DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
|
||||
|
||||
return DenseElementsAttr::get(valType, elementAttr);
|
||||
}
|
||||
|
||||
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||
// value.
|
||||
// Requires `ty` to be either FloatType of IntegerType.
|
||||
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||
|
||||
} // namespace xla
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -18,9 +18,7 @@ limitations under the License.
|
||||
#ifndef HLO_UTILS
|
||||
#define HLO_UTILS
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
|
||||
@ -34,4 +32,9 @@ def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||
|
||||
// Here, the element type can be any integer or float type. But, note that only
|
||||
// 32 bit integers are supported for the value.
|
||||
class GetScalarOfType<int value> : NativeCodeCall<
|
||||
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||
|
||||
#endif // HLO_UTILS
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
|
||||
return value.convertToDouble();
|
||||
}
|
||||
|
||||
static absl::string_view ConvertStringRef(mlir::StringRef value) {
|
||||
return {value.data(), value.size()};
|
||||
}
|
||||
|
||||
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
|
||||
auto values = attr.getValues<int64>();
|
||||
return {values.begin(), values.end()};
|
||||
@ -632,6 +637,12 @@ LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
|
||||
// Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
|
||||
// operands.
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
|
||||
xla::XlaComputation condition;
|
||||
xla::XlaComputation body;
|
||||
|
@ -32,17 +32,20 @@ using llvm::raw_ostream;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::StringRef;
|
||||
using mlir::interleaveComma;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::NamedTypeConstraint;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
static std::string GetDefaultAttrExport(
|
||||
const mlir::tblgen::NamedAttribute& named_attr) {
|
||||
auto storage_type = named_attr.attr.getStorageType();
|
||||
Attribute attr = named_attr.attr;
|
||||
StringRef storage_type = attr.getStorageType();
|
||||
// For some attribute types we have a general conversion, so use that.
|
||||
if (storage_type.endswith("IntegerAttr") ||
|
||||
storage_type.endswith("FloatAttr")) {
|
||||
return "Convert" + named_attr.attr.getReturnType().str();
|
||||
if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
|
||||
storage_type.endswith("FloatAttr") ||
|
||||
storage_type.endswith("StringAttr"))) {
|
||||
return "Convert" + attr.getReturnType().str();
|
||||
}
|
||||
return "Convert_" + named_attr.name.str();
|
||||
}
|
||||
|
@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f
|
||||
// CHECK: return %arg0
|
||||
return %2 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @unary_einsum
|
||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
|
||||
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -215,6 +215,20 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
return %0: tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @einsum
|
||||
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
|
||||
// CHECK: xla_hlo.einsum
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
|
||||
return %0: tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unary_einsum
|
||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: xla_hlo.unary_einsum
|
||||
%0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0: tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_broadcast_i32
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
|
||||
|
9
tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
Normal file
9
tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
Normal file
@ -0,0 +1,9 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
|
||||
// Simple einsum is lowered to HLO dot op.
|
||||
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
|
||||
return %0 : tensor<3x5xi32>
|
||||
}
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicSlice op patterns.
|
||||
@ -37,3 +37,13 @@ def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
|
||||
(HLO_SliceOp $input, (CastIntElementsAttr $starting_indices),
|
||||
(BuildSliceLimits $starting_indices, $slice_sizes),
|
||||
(BuildSliceStrides $input))>;
|
||||
|
||||
def UnaryToBianryEinsumEq : NativeCodeCall<
|
||||
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
|
||||
|
||||
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
|
||||
// operand.
|
||||
def UnaryEinsumToEinsum : Pat<
|
||||
(HLO_UnaryEinsumOp $operand, $equation),
|
||||
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||
$operand, (UnaryToBianryEinsumEq $equation))>;
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
|
||||
@ -168,20 +169,9 @@ static ConstOp GetMinValueForType(Type ty, Location loc,
|
||||
|
||||
// Returns int or float scalar DenseElementsAttr attribute with the given
|
||||
// element type and the value.
|
||||
static ConstOp GetScalarOfType(Type ty, Location loc, int64_t raw_value,
|
||||
PatternRewriter *rewriter) {
|
||||
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||
|
||||
DenseElementsAttr attr;
|
||||
if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
|
||||
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||
attr = DenseElementsAttr::get(scalar_ty, value);
|
||||
} else {
|
||||
auto int_ty = ty.cast<IntegerType>();
|
||||
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||
attr = DenseElementsAttr::get(scalar_ty, value);
|
||||
}
|
||||
return rewriter->create<ConstOp>(loc, attr);
|
||||
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
|
||||
PatternRewriter *rewriter) {
|
||||
return rewriter->create<ConstOp>(loc, xla::GetScalarOfType(ty, raw_value));
|
||||
}
|
||||
|
||||
// Builds body for reduce op by using the using the template binary op as the
|
||||
@ -639,6 +629,31 @@ class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
|
||||
// depending on arity of the op.
|
||||
class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(TF::EinsumOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
StringAttr equation = op.getAttrOfType<StringAttr>("equation");
|
||||
if (op.N() == 1) {
|
||||
rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
|
||||
op, op.getType(), *op.inputs().begin(), equation);
|
||||
} else if (op.N() == 2) {
|
||||
auto inputs = llvm::to_vector<2>(op.inputs());
|
||||
rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
|
||||
inputs[1], equation);
|
||||
} else {
|
||||
// TensorFlow EinsumOp verifies that the number of operands are at most
|
||||
// two.
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
return Pattern::matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
|
||||
// dimensions with max as the reduction function.
|
||||
//
|
||||
@ -847,8 +862,8 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
|
||||
const int64_t rank = input_ty.getRank();
|
||||
auto result_type = op.getResult()->getType();
|
||||
Operation *size =
|
||||
GetScalarOfType(result_type.cast<TensorType>().getElementType(),
|
||||
op.getLoc(), 1, &rewriter);
|
||||
GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
|
||||
op.getLoc(), 1, &rewriter);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
auto dim = rewriter.create<GetDimensionSizeOp>(
|
||||
op.getLoc(), result_type, input,
|
||||
@ -1169,8 +1184,8 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
|
||||
divisor_count *= input_shape[i];
|
||||
}
|
||||
}
|
||||
auto divisor =
|
||||
GetScalarOfType(reduce_element_type, loc, divisor_count, &rewriter);
|
||||
auto divisor = GetScalarConstOfType(reduce_element_type, loc,
|
||||
divisor_count, &rewriter);
|
||||
auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
|
||||
result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
|
||||
broadcast_dims);
|
||||
@ -1203,7 +1218,7 @@ class ConvertMeanOp
|
||||
|
||||
static Value *GetInitialValue(Type reduce_element_type, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
|
||||
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1219,7 +1234,7 @@ class ConvertSumOp
|
||||
|
||||
static Value *GetInitialValue(Type reduce_element_type, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
|
||||
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1274,7 +1289,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
|
||||
|
||||
Type index_element_type = output_type.getElementType();
|
||||
Value *index_init_value =
|
||||
GetScalarOfType(index_element_type, loc, 0, &rewriter);
|
||||
GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
|
||||
|
||||
RankedTensorType index_type =
|
||||
RankedTensorType::get(input_type.getShape(), index_element_type);
|
||||
@ -1418,7 +1433,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
|
||||
|
||||
auto result = rewriter.create<SelectAndScatterOp>(
|
||||
loc, op.getType(), op.orig_input(), op.grad(),
|
||||
GetScalarOfType(element_type, loc, 0, &rewriter),
|
||||
GetScalarConstOfType(element_type, loc, 0, &rewriter),
|
||||
GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
|
||||
nullptr);
|
||||
|
||||
@ -1860,8 +1875,9 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
|
||||
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
patterns
|
||||
.insert<ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D,
|
||||
ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertEinsumOp, ConvertMaxPoolOp, ConvertRangeOp,
|
||||
ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
|
||||
ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
|
||||
ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
|
||||
ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,
|
||||
|
Loading…
x
Reference in New Issue
Block a user