From 513f16d55d31f2daf21972f910b0438fad3f151e Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Mon, 2 Dec 2019 16:24:53 -0800
Subject: [PATCH] Lower TensorFlow Einsum op to HLO

Specifically,

* Verify TF Einsum op arity
* Define Einsum and UnaryEinsum op in HLO. UnaryEinsum is defined so that the
  op is not variadic and does not require conversion from unary to binary form
  at the time of import. This way translations are simplified and it is also
  easier to operate on the op.
* Convert TF Einsum op to HLO Einsum op or UnaryEinsum op.
* Canonicalize UnaryEinsum to Einsum op with two operands.

Also, added support for StringAttr in the HLO exporter generator.

PiperOrigin-RevId: 283441951
Change-Id: I0d86ae2723209b08eb2c243a9dcde843dbd5897c
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  4 ++
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 15 +++++
 .../mlir/tensorflow/tests/tf-ops.mlir         |  8 +++
 tensorflow/compiler/mlir/xla/BUILD            |  1 +
 tensorflow/compiler/mlir/xla/ir/hlo_ops.cc    | 11 +++-
 tensorflow/compiler/mlir/xla/ir/hlo_ops.td    | 37 +++++++++++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.cc  | 15 +++++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.h   |  7 ++
 tensorflow/compiler/mlir/xla/ir/hlo_utils.td  |  7 +-
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      | 11 ++++
 .../compiler/mlir/xla/operator_writer_gen.cc  | 11 ++--
 .../compiler/mlir/xla/tests/canonicalize.mlir |  8 +++
 .../compiler/mlir/xla/tests/legalize-tf.mlir  | 14 ++++
 .../mlir/xla/tests/translate/einsum.mlir      |  9 +++
 .../mlir/xla/transforms/canonicalize.td       | 12 +++-
 .../mlir/xla/transforms/legalize_tf.cc        | 64 ++++++++++++-------
 16 files changed, 202 insertions(+), 32 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b68634ba704..57b61461d02 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1358,6 +1358,10 @@ Comparison with `numpy.einsum`:
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_EluOp : TF_Op<"Elu", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index a58e20a9952..3b836a6188d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -674,6 +674,21 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<DivWithSqrtDivisor>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// EinsumOp
+//===----------------------------------------------------------------------===//
+
+// Verifies that,
+// * Arity of the op is at most two.
+//
+// TODO(hinsu): Verify einsum equation attribute.
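+//
+// For example, a two operand equation such as "ab,bc->ac" passes this check,
+// while a three operand equation such as "ab,cd,ef->" is rejected by the
+// verifier below.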
+static LogicalResult Verify(EinsumOp op) {
+  if (op.N() > 2) {
+    return op.emitOpError("supports at most two operands");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // EmptyTensorListOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index b8e7ba71198..1914ca177cc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1650,3 +1650,11 @@ func @testSplitSmallSplitDim(%input: tensor<4x8xf32>) {
   %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
   return
 }
+
+// -----
+
+func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
+  // expected-error @+1 {{supports at most two operands}}
+  %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 3ed3fb6fc40..ac3475cebc4 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -404,6 +404,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:matrix",
         "//tensorflow/compiler/xla/service:hlo",
         "@llvm//:support",
         "@local_config_mlir//:Analysis",
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 639c85c48b5..8fa33d19363 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -47,10 +47,10 @@ limitations under the License.
 #include "mlir/Transforms/InliningUtils.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
 
 namespace mlir {
 #include "tensorflow/compiler/mlir/xla/ir/hlo_structs.cc.inc"
-
 namespace xla_hlo {
 
 Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
@@ -936,6 +936,15 @@ void TupleOp::build(Builder* builder, OperationState& result,
   build(builder, result, builder->getTupleType(types), values);
 }
 
+//===----------------------------------------------------------------------===//
+// UnaryEinsumOp
+//===----------------------------------------------------------------------===//
+
+void UnaryEinsumOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<UnaryEinsumToEinsum>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CompareOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index f036dec92b9..4fb85f9f6b3 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -730,6 +730,43 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
   let results = (outs HLO_Tensor);
 }
 
+def BASE_EinsumOp {
+  string summary = "Einsum operator";
+
+  string description = [{
+    Returns a tensor whose elements are defined by equation, which is written
+    in a shorthand form inspired by the Einstein summation convention.
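+
+    For example, the einsum_config "ab,bc->ac" contracts the shared b
+    dimension of the two operands, which is equivalent to a matrix
+    multiplication of lhs and rhs.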
+  }];
+}
+
+def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
+  let arguments = (ins
+    HLO_Tensor:$lhs,
+    HLO_Tensor:$rhs,
+    StrAttr:$einsum_config
+  );
+
+  let results = (outs HLO_Tensor);
+
+  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
+  // side HLO ops.
+}
+
+def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    StrAttr:$einsum_config
+  );
+
+  let results = (outs HLO_Tensor);
+
+  let hasCanonicalizer = 1;
+
+  // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
+  // the HLO converter shouldn't be invoked.
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
   let arguments = (ins
     HLO_Tensor:$operand,
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
index 82b7032d542..7d3e2ca2384 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <numeric>
 
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+
 namespace mlir {
 namespace xla {
 
@@ -51,5 +53,18 @@ DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
   return DenseIntElementsAttr::get(type, broadcastDimensions);
 }
 
+DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
+  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
+
+  DenseElementsAttr attr;
+  if (auto float_ty = ty.dyn_cast<FloatType>()) {
+    APFloat value(float_ty.getFloatSemantics(), raw_value);
+    return DenseElementsAttr::get(scalar_ty, value);
+  }
+  auto int_ty = ty.cast<IntegerType>();
+  APInt value(int_ty.getWidth(), static_cast<uint64_t>(raw_value), true);
+  return DenseElementsAttr::get(scalar_ty, value);
+}
+
 }  // namespace xla
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
index d81abf6a0be..86c90b49f16 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@@ -48,6 +49,12 @@ static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
   return DenseElementsAttr::get(valType, elementAttr);
 }
 
+// Returns DenseElementsAttr of rank zero with the given element type and the
+// value.
+// Requires `ty` to be either FloatType or IntegerType.
+DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
+
 }  // namespace xla
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
index 1a56d230d0d..97b29bf0851 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.td
@@ -18,9 +18,7 @@ limitations under the License.
 #ifndef HLO_UTILS
 #define HLO_UTILS
 
-#ifndef OP_BASE
 include "mlir/IR/OpBase.td"
-#endif // OP_BASE
 
 def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
 
@@ -34,4 +32,9 @@ def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 def BinBroadcastDimensions : NativeCodeCall<
   "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
 
+// Here, the element type can be any integer or float type, but note that only
+// 32 bit integers are supported for the value.
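+//
+// For example, (GetScalarOfType<1> $operand) in a rewrite pattern produces a
+// rank zero DenseElementsAttr holding the value one, using the element type
+// of $operand; the UnaryEinsumToEinsum pattern in canonicalize.td (later in
+// this patch) uses this to synthesize the extra einsum operand.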
+class GetScalarOfType<int value> : NativeCodeCall<
+  "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
+
 #endif // HLO_UTILS
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 267fd3b21b4..f717c8199fd 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -77,6 +78,10 @@ static double ConvertAPFloat(llvm::APFloat value) {
   return value.convertToDouble();
 }
 
+static absl::string_view ConvertStringRef(mlir::StringRef value) {
+  return {value.data(), value.size()};
+}
+
 static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
   auto values = attr.getValues<int64>();
   return {values.begin(), values.end()};
 }
@@ -632,6 +637,12 @@ LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
   return success();
 }
 
+LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
+  // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
+  // operands.
+  return failure();
+}
+
 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
   xla::XlaComputation condition;
   xla::XlaComputation body;
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index 4a9555a256a..acc3c17baf5 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -32,17 +32,20 @@ using llvm::raw_ostream;
 using llvm::RecordKeeper;
 using llvm::StringRef;
 using mlir::interleaveComma;
+using mlir::tblgen::Attribute;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::Operator;
 
 static std::string GetDefaultAttrExport(
     const mlir::tblgen::NamedAttribute& named_attr) {
-  auto storage_type = named_attr.attr.getStorageType();
+  Attribute attr = named_attr.attr;
+  StringRef storage_type = attr.getStorageType();
   // For some attribute types we have a general conversion, so use that.
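+  // For example, a StrAttr such as einsum_config has storage type StringAttr
+  // and return type StringRef, so the generated exporter calls the generic
+  // ConvertStringRef helper added to mlir_hlo_to_hlo.cc in this patch.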
-  if (storage_type.endswith("IntegerAttr") ||
-      storage_type.endswith("FloatAttr")) {
-    return "Convert" + named_attr.attr.getReturnType().str();
+  if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
+                             storage_type.endswith("FloatAttr") ||
+                             storage_type.endswith("StringAttr"))) {
+    return "Convert" + attr.getReturnType().str();
   }
   return "Convert_" + named_attr.name.str();
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index e6d99b9e7d8..fa39b77918a 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -48,3 +48,11 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
 }
+
+// CHECK-LABEL: @unary_einsum
+func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
+  // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
+  %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 3004f2276fe..d8093f1a39a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -215,6 +215,20 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   return %0: tensor<?xf32>
 }
 
+// CHECK-LABEL: func @einsum
+func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
+  // CHECK: xla_hlo.einsum
+  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
+  return %0: tensor<2x4xf32>
+}
+
+// CHECK-LABEL: func @unary_einsum
+func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
+  // CHECK: xla_hlo.unary_einsum
+  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
+  return %0: tensor<2x2xf32>
+}
+
 // CHECK-LABEL: func @floordiv_broadcast_i32
 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
   // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir b/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
new file mode 100644
index 00000000000..e703a5cb872
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/einsum.mlir
@@ -0,0 +1,9 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// CHECK-LABEL: ENTRY
+func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
+  // Simple einsum is lowered to HLO dot op.
+  // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
+  return %0 : tensor<3x5xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
index bc44117910b..37f6d7deaa3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td
@@ -17,7 +17,7 @@ limitations under the License.
 
include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" - +include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" //===----------------------------------------------------------------------===// // DynamicSlice op patterns. @@ -37,3 +37,13 @@ def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input, (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices), (BuildSliceLimits $starting_indices, $slice_sizes), (BuildSliceStrides $input))>; + +def UnaryToBianryEinsumEq : NativeCodeCall< + "$_builder.getStringAttr(\",\" + $0.getValue().str())">; + +// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first +// operand. +def UnaryEinsumToEinsum : Pat< + (HLO_UnaryEinsumOp $operand, $equation), + (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), + $operand, (UnaryToBianryEinsumEq $equation))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index a156685f005..94e0ce35cb0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir @@ -168,20 +169,9 @@ static ConstOp GetMinValueForType(Type ty, Location loc, // Returns int or float scalar DenseElementsAttr attribute with the given // element type and the value. -static ConstOp GetScalarOfType(Type ty, Location loc, int64_t raw_value, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat value(float_ty.getFloatSemantics(), raw_value); - attr = DenseElementsAttr::get(scalar_ty, value); - } else { - auto int_ty = ty.cast(); - APInt value(int_ty.getWidth(), static_cast(raw_value), true); - attr = DenseElementsAttr::get(scalar_ty, value); - } - return rewriter->create(loc, attr); +static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + PatternRewriter *rewriter) { + return rewriter->create(loc, xla::GetScalarOfType(ty, raw_value)); } // Builds body for reduce op by using the using the template binary op as the @@ -639,6 +629,31 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { } }; +// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp +// depending on arity of the op. +class ConvertEinsumOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter &rewriter) const override { + StringAttr equation = op.getAttrOfType("equation"); + if (op.N() == 1) { + rewriter.replaceOpWithNewOp( + op, op.getType(), *op.inputs().begin(), equation); + } else if (op.N() == 2) { + auto inputs = llvm::to_vector<2>(op.inputs()); + rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], + inputs[1], equation); + } else { + // TensorFlow EinsumOp verifies that the number of operands are at most + // two. + return Pattern::matchFailure(); + } + return Pattern::matchSuccess(); + } +}; + // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window // dimensions with max as the reduction function. 
 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with max as the reduction function.
 //
@@ -847,8 +862,8 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
     const int64_t rank = input_ty.getRank();
     auto result_type = op.getResult()->getType();
     Operation *size =
-        GetScalarOfType(result_type.cast<RankedTensorType>().getElementType(),
-                        op.getLoc(), 1, &rewriter);
+        GetScalarConstOfType(result_type.cast<RankedTensorType>().getElementType(),
+                             op.getLoc(), 1, &rewriter);
     for (int64_t i = 0; i < rank; ++i) {
       auto dim = rewriter.create<GetDimensionSizeOp>(
           op.getLoc(), result_type, input,
@@ -1169,8 +1184,8 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
         divisor_count *= input_shape[i];
       }
     }
-    auto divisor =
-        GetScalarOfType(reduce_element_type, loc, divisor_count, &rewriter);
+    auto divisor = GetScalarConstOfType(reduce_element_type, loc,
+                                        divisor_count, &rewriter);
     auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
     result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
                                     broadcast_dims);
@@ -1203,7 +1218,7 @@ class ConvertMeanOp
 
   static Value *GetInitialValue(Type reduce_element_type, Location loc,
                                 PatternRewriter &rewriter) {
-    return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
+    return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
   }
 };
 
@@ -1219,7 +1234,7 @@ class ConvertSumOp
 
   static Value *GetInitialValue(Type reduce_element_type, Location loc,
                                 PatternRewriter &rewriter) {
-    return GetScalarOfType(reduce_element_type, loc, 0, &rewriter);
+    return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
   }
 };
 
@@ -1274,7 +1289,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
 
     Type index_element_type = output_type.getElementType();
     Value *index_init_value =
-        GetScalarOfType(index_element_type, loc, 0, &rewriter);
+        GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
 
     RankedTensorType index_type =
         RankedTensorType::get(input_type.getShape(), index_element_type);
@@ -1418,7 +1433,7 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
 
     auto result = rewriter.create<SelectAndScatterOp>(
         loc, op.getType(), op.orig_input(), op.grad(),
-        GetScalarOfType(element_type, loc, 0, &rewriter),
+        GetScalarConstOfType(element_type, loc, 0, &rewriter),
         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
         nullptr);
 
@@ -1860,8 +1875,9 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   TF::PopulateLoweringTFPatterns(context, &patterns);
   patterns
       .insert<ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D,
-              ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
+              ConvertEinsumOp, ConvertMaxPoolOp, ConvertRangeOp,
+              ConvertSigmoidOp, ConvertSizeOp,
               ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
               ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
               ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
               ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,