From 6072451f8b17d05336407aeeead6c5455a7dfc6a Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Mon, 16 Mar 2020 08:10:23 -0700 Subject: [PATCH] Initial support for Einsum op in TF Lite PiperOrigin-RevId: 301159606 Change-Id: Ie8d9ef631ea64061423f95c06050b6329c844847 --- .../mlir/lite/transforms/prepare_tf.cc | 5 +- tensorflow/compiler/mlir/tensorflow/BUILD | 3 + .../mlir/tensorflow/tests/einsum.mlir | 57 ++++ .../mlir/tensorflow/transforms/einsum.cc | 296 ++++++++++++++++++ .../mlir/tensorflow/transforms/einsum.h | 55 ++++ 5 files changed, 414 insertions(+), 2 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/einsum.h diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index ef6fd1899d2..6d86286749f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -56,6 +56,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -655,8 +656,8 @@ void PrepareTFPass::runOnFunction() { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); } - patterns.insert(ctx); + patterns.insert(ctx); applyPatternsGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index fb5a60d38ab..90370bbed61 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -209,6 +209,7 @@ cc_library( "ir/tf_traits.h", "ir/tf_verifiers.h", "transforms/bridge.h", + "transforms/einsum.h", "transforms/passes.h", "transforms/unroll_batch_matmul.h", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h", @@ -310,6 +311,7 @@ cc_library( "transforms/cluster_outlining.cc", "transforms/collection_ops_util.cc", "transforms/decompose_resource_ops_pass.cc", + "transforms/einsum.cc", "transforms/executor_island_coarsening.cc", "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", @@ -357,6 +359,7 @@ cc_library( "transforms/batchmatmul_to_einsum.h", "transforms/bridge.h", "transforms/collection_ops_util.h", + "transforms/einsum.h", "transforms/passes.h", "transforms/shape_inference.h", ], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir new file mode 100644 index 00000000000..3dec94a98df --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -0,0 +1,57 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-einsum %s | FileCheck %s + +func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + // CHECK-LABEL: einsum_basic + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> +} + +func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> + return %0 : tensor<2x7x5x4xf32> + // CHECK-LABEL: einsum_4D + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x4x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x4xf32> + // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32> +} + +func @einsum_matrixdotprod(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<7x3x4xf32>) -> tensor<2x5x4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnd,ndh->bfh"}: (tensor<2x5x7x3xf32>, tensor<7x3x4xf32>) -> tensor<2x5x4xf32> + return %0 : tensor<2x5x4xf32> + // CHECK-LABEL: einsum_matrixdotprod + // CHECK: %[[cst:.*]] = constant dense<[2, 5, 21]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[21, 4]> : tensor<2xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<3xi64>) -> tensor<2x5x21xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<7x3x4xf32>, tensor<2xi64>) -> tensor<21x4xf32> + // CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32> +} + +func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfd,dnh->bfnh"}: (tensor<3x4x5xf32>, tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> + return %0 : tensor<3x4x6x2xf32> + // CHECK-LABEL: einsum_reshapetail + // CHECK: %[[cst:.*]] = constant dense<[5, 12]> : tensor<2xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[3, 4, 6, 2]> : tensor<4xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<5x6x2xf32>, tensor<2xi64>) -> tensor<5x12xf32> + // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> tensor<3x4x12xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<3x4x12xf32>, tensor<4xi64>) -> tensor<3x4x6x2xf32> + // CHECK: return %[[v2]] : tensor<3x4x6x2xf32> +} + +func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +// CHECK-LABEL: einsum_no_match +// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: return %[[v0]] +} +func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +// CHECK-LABEL: einsum_illegal_no_match +// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: return %[[v0]] +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc new file mode 100644 index 00000000000..d7fcb232dac --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -0,0 +1,296 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Regex.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/OpImplementation.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Support/Functional.h" // TF:llvm-project +#include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Support/LogicalResult.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +namespace { + +// All supported Einsum equations. +enum EinsumEquation { + BatchMatMul, + FourDMatrixDotProd, + ThreeDReshapeTail, + FourDBatchMatMul, + UnsupportedEquation +}; + +// Tokens for parsing the given equation string. +enum EquationToken { + A, + B, + C, + D, + E, + COMMA, + ARROW, +}; +inline constexpr int kNumSupportedEquationVariables = 5; // A - E for now. + +bool tokenizeEquation(const llvm::StringRef& equation, + std::vector* tokens) { + std::map label_axis_mapping; + int index = 0; + int variable_count = 0; + llvm::Regex r("[[:alpha:]]"); + while (index < equation.size()) { + if (r.match(equation.substr(index, 1))) { + const char ltr = equation[index]; + auto itr = label_axis_mapping.find(ltr); + if (itr == label_axis_mapping.end() && + variable_count < kNumSupportedEquationVariables) { + label_axis_mapping[ltr] = EquationToken(variable_count); + tokens->push_back(EquationToken(variable_count)); + variable_count++; + } else if (itr != label_axis_mapping.end()) { + tokens->push_back(itr->second); + } else { + // Ran out of equation variables. + return false; + } + } else if (equation.substr(index, 1).contains(",")) { + tokens->push_back(COMMA); + } else if ((index < (equation.size() - 1)) && + (equation.substr(index, 2).contains("->"))) { + tokens->push_back(ARROW); + index++; + } else { + // Unallowed character encountered. + return false; + } + index++; + } + return true; +} + +EinsumEquation parseEquation(const std::vector& eqn) { + auto is_equal = [](const std::vector& eqn1, + const std::initializer_list& eqn2) { + return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end()); + }; + // IJK,IKM->IJM + if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) { + return EinsumEquation::BatchMatMul; + } + // BFND,NDH->BFH + if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) { + return EinsumEquation::FourDMatrixDotProd; + } + // BFNH,BTNH->BNFT + if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) { + return EinsumEquation::FourDBatchMatMul; + } + // BFD,DNH->BFNH + if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) { + return EinsumEquation::ThreeDReshapeTail; + } + return EinsumEquation::UnsupportedEquation; +} + +EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) { + std::vector tokens; + if (tokenizeEquation(equation, &tokens)) { + return parseEquation(tokens); + } + return EinsumEquation::UnsupportedEquation; +} + +TF::TransposeOp createTransposeOp(Value value, Location loc, + llvm::ArrayRef permutation, + PatternRewriter* rewriter) { + auto value_type = value.getType().cast(); + auto shape = value_type.getShape(); + auto perm_type = RankedTensorType::get( + {static_cast(permutation.size())}, rewriter->getIntegerType(32)); + auto perm_attr = DenseElementsAttr::get(perm_type, permutation); + auto perm_op = rewriter->create(loc, perm_type, perm_attr); + std::vector transposed_shape(shape.begin(), shape.end()); + for (int i = 0; i < shape.size(); ++i) { + transposed_shape[i] = shape[permutation[i]]; + } + auto transposed_type = + RankedTensorType::get(transposed_shape, value_type.getElementType()); + return rewriter->create(loc, transposed_type, value, + perm_op); +} + +TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, + Type element_type, Location loc, + PatternRewriter* rewriter) { + int64_t shape_rank = shape.size(); + auto shape_spec_type = + RankedTensorType::get({shape_rank}, rewriter->getIntegerType(64)); + Type resultType = RankedTensorType::get(shape, element_type); + auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape); + auto shape_tensor = + rewriter->create(loc, shape_spec_type, constant_attr); + return rewriter->create(loc, resultType, /*tensor=*/value, + /*shape=*/shape_tensor); +} + +} // namespace + +PatternMatchResult ConvertTFEinsumOp::matchAndRewrite( + TF::EinsumOp op, PatternRewriter& rewriter) const { + Type output_type = op.getResult().getType(); + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + Location loc = op.getLoc(); + + if (!lhs.getType().isa()) { + // LHS must be a ranked tensor type + return matchFailure(); + } + if (!rhs.getType().isa()) { + // RHS must be a ranked tensor type + return matchFailure(); + } + + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + // Currently only support static shapes. + if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) { + return matchFailure(); + } + + // Currently support use cases of LHS, RHS dims = 3 or 4 + const int dims_lhs = lhs_shape.size(); + const int dims_rhs = rhs_shape.size(); + if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) { + return matchFailure(); + } + + EinsumEquation einsum_eqn = tokenizeAndParse(op.equation()); + if (einsum_eqn == EinsumEquation::BatchMatMul) { + // Case "IJK,IKM->IJM" + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, bmm_op.getResult()); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { + // Case "BFD,DNH->BFNH" + auto lhs_type = lhs.getType().cast(); + auto lhs_shape = lhs_type.getShape(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + // Reshape RHS + auto rhs_type = rhs.getType().cast(); + auto rhs_shape = rhs_type.getShape(); + auto rhs_element_type = rhs_type.getElementType(); + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim1 = rhs_shape[1]; + const int rhs_dim2 = rhs_shape[2]; + auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2}, + rhs_element_type, loc, &rewriter); + + std::vector bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, lhs, reshaped_rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + auto bmm_element_type = bmm_type.getElementType(); + auto final_reshape = + createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2}, + bmm_element_type, loc, &rewriter); + rewriter.replaceOp(op, {final_reshape.getResult()}); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) { + // Case "BFND,NDH->BFH" + // Reshape LHS + auto lhs_element_type = lhs_type.getElementType(); + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + const int lhs_dim2 = lhs_shape[2]; + const int lhs_dim3 = lhs_shape[3]; + auto reshaped_lhs = + createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3}, + lhs_element_type, loc, &rewriter); + // Reshape RHS + auto rhs_element_type = rhs_type.getElementType(); + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim1 = rhs_shape[1]; + const int rhs_dim2 = rhs_shape[2]; + auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2}, + rhs_element_type, loc, &rewriter); + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, reshaped_lhs, reshaped_rhs, + rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, {bmm_op.getResult()}); + return matchSuccess(); + } + if (einsum_eqn == EinsumEquation::FourDBatchMatMul) { + // Case "BFNH,BTNH->BNFT" + // Transpose LHS + lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); + // Transpose RHS + rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); + auto bmm_op = rewriter.create( + loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + rewriter.replaceOp(op, {bmm_op.getResult()}); + return matchSuccess(); + } + return matchFailure(); +} + +// Transform Einsum to other TF Ops for the supported variants. +struct TransformEinsumPass : public FunctionPass { + void runOnFunction() override; +}; + +void TransformEinsumPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert(&getContext()); + applyPatternsGreedily(func, patterns); +} + +static PassRegistration pass( + "tf-einsum", "Transform Einsum to other TF Ops for the supported variants"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h new file mode 100644 index 00000000000..77b0c72aaef --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This pass identifies patterns for certain Einsum Ops and replaces them +// with other equivalent TF Ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" // TF:llvm-project +#include "mlir/IR/PatternMatch.h" // TF:llvm-project +#include "mlir/IR/StandardTypes.h" // TF:llvm-project +#include "mlir/IR/TypeUtilities.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +// TF.Einsum provides fully general tensor contractions. For a few select +// cases, we can convert this op to other TF Ops, which in later passes +// properly convert to TF Lite ops. +struct ConvertTFEinsumOp : public OpRewritePattern { + public: + explicit ConvertTFEinsumOp(MLIRContext* context) + : OpRewritePattern(context) {} + + PatternMatchResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter& rewriter) const override; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_