Initial support for Einsum op in TF Lite

PiperOrigin-RevId: 301159606
Change-Id: Ie8d9ef631ea64061423f95c06050b6329c844847
This commit is contained in:
T.J. Alumbaugh 2020-03-16 08:10:23 -07:00 committed by TensorFlower Gardener
parent 32f5b7dd76
commit 6072451f8b
5 changed files with 414 additions and 2 deletions

View File

@ -56,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#define DEBUG_TYPE "tf-tfl-legalization" #define DEBUG_TYPE "tf-tfl-legalization"
@ -655,8 +656,8 @@ void PrepareTFPass::runOnFunction() {
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>, patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx); TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
} }
patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative, patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
ConvertTFStridedSlice>(ctx); ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
applyPatternsGreedily(func, patterns); applyPatternsGreedily(func, patterns);
} }

View File

@ -209,6 +209,7 @@ cc_library(
"ir/tf_traits.h", "ir/tf_traits.h",
"ir/tf_verifiers.h", "ir/tf_verifiers.h",
"transforms/bridge.h", "transforms/bridge.h",
"transforms/einsum.h",
"transforms/passes.h", "transforms/passes.h",
"transforms/unroll_batch_matmul.h", "transforms/unroll_batch_matmul.h",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h",
@ -310,6 +311,7 @@ cc_library(
"transforms/cluster_outlining.cc", "transforms/cluster_outlining.cc",
"transforms/collection_ops_util.cc", "transforms/collection_ops_util.cc",
"transforms/decompose_resource_ops_pass.cc", "transforms/decompose_resource_ops_pass.cc",
"transforms/einsum.cc",
"transforms/executor_island_coarsening.cc", "transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_inline_tpu_island.cc",
"transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_island_coarsening.cc",
@ -357,6 +359,7 @@ cc_library(
"transforms/batchmatmul_to_einsum.h", "transforms/batchmatmul_to_einsum.h",
"transforms/bridge.h", "transforms/bridge.h",
"transforms/collection_ops_util.h", "transforms/collection_ops_util.h",
"transforms/einsum.h",
"transforms/passes.h", "transforms/passes.h",
"transforms/shape_inference.h", "transforms/shape_inference.h",
], ],

View File

@ -0,0 +1,57 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-einsum %s | FileCheck %s
func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: einsum_basic
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
}
func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
return %0 : tensor<2x7x5x4xf32>
// CHECK-LABEL: einsum_4D
// CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32>
// CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32>
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x4x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x4xf32>
// CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32>
}
func @einsum_matrixdotprod(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<7x3x4xf32>) -> tensor<2x5x4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnd,ndh->bfh"}: (tensor<2x5x7x3xf32>, tensor<7x3x4xf32>) -> tensor<2x5x4xf32>
return %0 : tensor<2x5x4xf32>
// CHECK-LABEL: einsum_matrixdotprod
// CHECK: %[[cst:.*]] = constant dense<[2, 5, 21]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[21, 4]> : tensor<2xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<3xi64>) -> tensor<2x5x21xf32>
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<7x3x4xf32>, tensor<2xi64>) -> tensor<21x4xf32>
// CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32>
}
func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfd,dnh->bfnh"}: (tensor<3x4x5xf32>, tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32>
return %0 : tensor<3x4x6x2xf32>
// CHECK-LABEL: einsum_reshapetail
// CHECK: %[[cst:.*]] = constant dense<[5, 12]> : tensor<2xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[3, 4, 6, 2]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<5x6x2xf32>, tensor<2xi64>) -> tensor<5x12xf32>
// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> tensor<3x4x12xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<3x4x12xf32>, tensor<4xi64>) -> tensor<3x4x6x2xf32>
// CHECK: return %[[v2]] : tensor<3x4x6x2xf32>
}
func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
// CHECK-LABEL: einsum_no_match
// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
// CHECK: return %[[v0]]
}
func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
// CHECK-LABEL: einsum_illegal_no_match
// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
// CHECK: return %[[v0]]
}

View File

@ -0,0 +1,296 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include <climits>
#include <cstdint>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Regex.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TF {
namespace {
// All supported Einsum equations.
enum EinsumEquation {
BatchMatMul,
FourDMatrixDotProd,
ThreeDReshapeTail,
FourDBatchMatMul,
UnsupportedEquation
};
// Tokens for parsing the given equation string.
enum EquationToken {
A,
B,
C,
D,
E,
COMMA,
ARROW,
};
inline constexpr int kNumSupportedEquationVariables = 5; // A - E for now.
bool tokenizeEquation(const llvm::StringRef& equation,
std::vector<EquationToken>* tokens) {
std::map<char, EquationToken> label_axis_mapping;
int index = 0;
int variable_count = 0;
llvm::Regex r("[[:alpha:]]");
while (index < equation.size()) {
if (r.match(equation.substr(index, 1))) {
const char ltr = equation[index];
auto itr = label_axis_mapping.find(ltr);
if (itr == label_axis_mapping.end() &&
variable_count < kNumSupportedEquationVariables) {
label_axis_mapping[ltr] = EquationToken(variable_count);
tokens->push_back(EquationToken(variable_count));
variable_count++;
} else if (itr != label_axis_mapping.end()) {
tokens->push_back(itr->second);
} else {
// Ran out of equation variables.
return false;
}
} else if (equation.substr(index, 1).contains(",")) {
tokens->push_back(COMMA);
} else if ((index < (equation.size() - 1)) &&
(equation.substr(index, 2).contains("->"))) {
tokens->push_back(ARROW);
index++;
} else {
// Unallowed character encountered.
return false;
}
index++;
}
return true;
}
EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
auto is_equal = [](const std::vector<EquationToken>& eqn1,
const std::initializer_list<EquationToken>& eqn2) {
return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end());
};
// IJK,IKM->IJM
if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) {
return EinsumEquation::BatchMatMul;
}
// BFND,NDH->BFH
if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) {
return EinsumEquation::FourDMatrixDotProd;
}
// BFNH,BTNH->BNFT
if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) {
return EinsumEquation::FourDBatchMatMul;
}
// BFD,DNH->BFNH
if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
return EinsumEquation::ThreeDReshapeTail;
}
return EinsumEquation::UnsupportedEquation;
}
EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) {
std::vector<EquationToken> tokens;
if (tokenizeEquation(equation, &tokens)) {
return parseEquation(tokens);
}
return EinsumEquation::UnsupportedEquation;
}
TF::TransposeOp createTransposeOp(Value value, Location loc,
llvm::ArrayRef<int32_t> permutation,
PatternRewriter* rewriter) {
auto value_type = value.getType().cast<RankedTensorType>();
auto shape = value_type.getShape();
auto perm_type = RankedTensorType::get(
{static_cast<int32_t>(permutation.size())}, rewriter->getIntegerType(32));
auto perm_attr = DenseElementsAttr::get(perm_type, permutation);
auto perm_op = rewriter->create<ConstantOp>(loc, perm_type, perm_attr);
std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
for (int i = 0; i < shape.size(); ++i) {
transposed_shape[i] = shape[permutation[i]];
}
auto transposed_type =
RankedTensorType::get(transposed_shape, value_type.getElementType());
return rewriter->create<TF::TransposeOp>(loc, transposed_type, value,
perm_op);
}
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
Type element_type, Location loc,
PatternRewriter* rewriter) {
int64_t shape_rank = shape.size();
auto shape_spec_type =
RankedTensorType::get({shape_rank}, rewriter->getIntegerType(64));
Type resultType = RankedTensorType::get(shape, element_type);
auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
auto shape_tensor =
rewriter->create<ConstantOp>(loc, shape_spec_type, constant_attr);
return rewriter->create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
/*shape=*/shape_tensor);
}
} // namespace
PatternMatchResult ConvertTFEinsumOp::matchAndRewrite(
TF::EinsumOp op, PatternRewriter& rewriter) const {
Type output_type = op.getResult().getType();
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
Location loc = op.getLoc();
if (!lhs.getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return matchFailure();
}
if (!rhs.getType().isa<RankedTensorType>()) {
// RHS must be a ranked tensor type
return matchFailure();
}
auto lhs_type = lhs.getType().cast<RankedTensorType>();
auto rhs_type = rhs.getType().cast<RankedTensorType>();
auto lhs_shape = lhs_type.getShape();
auto rhs_shape = rhs_type.getShape();
// Currently only support static shapes.
if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) {
return matchFailure();
}
// Currently support use cases of LHS, RHS dims = 3 or 4
const int dims_lhs = lhs_shape.size();
const int dims_rhs = rhs_shape.size();
if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
return matchFailure();
}
EinsumEquation einsum_eqn = tokenizeAndParse(op.equation());
if (einsum_eqn == EinsumEquation::BatchMatMul) {
// Case "IJK,IKM->IJM"
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, bmm_op.getResult());
return matchSuccess();
}
if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
// Case "BFD,DNH->BFNH"
auto lhs_type = lhs.getType().cast<RankedTensorType>();
auto lhs_shape = lhs_type.getShape();
const int lhs_dim0 = lhs_shape[0];
const int lhs_dim1 = lhs_shape[1];
// Reshape RHS
auto rhs_type = rhs.getType().cast<RankedTensorType>();
auto rhs_shape = rhs_type.getShape();
auto rhs_element_type = rhs_type.getElementType();
const int rhs_dim0 = rhs_shape[0];
const int rhs_dim1 = rhs_shape[1];
const int rhs_dim2 = rhs_shape[2];
auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2},
rhs_element_type, loc, &rewriter);
std::vector<int64_t> bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2};
auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{bmm_type}, lhs, reshaped_rhs,
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
auto bmm_element_type = bmm_type.getElementType();
auto final_reshape =
createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2},
bmm_element_type, loc, &rewriter);
rewriter.replaceOp(op, {final_reshape.getResult()});
return matchSuccess();
}
if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) {
// Case "BFND,NDH->BFH"
// Reshape LHS
auto lhs_element_type = lhs_type.getElementType();
const int lhs_dim0 = lhs_shape[0];
const int lhs_dim1 = lhs_shape[1];
const int lhs_dim2 = lhs_shape[2];
const int lhs_dim3 = lhs_shape[3];
auto reshaped_lhs =
createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3},
lhs_element_type, loc, &rewriter);
// Reshape RHS
auto rhs_element_type = rhs_type.getElementType();
const int rhs_dim0 = rhs_shape[0];
const int rhs_dim1 = rhs_shape[1];
const int rhs_dim2 = rhs_shape[2];
auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2},
rhs_element_type, loc, &rewriter);
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{output_type}, reshaped_lhs, reshaped_rhs,
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {bmm_op.getResult()});
return matchSuccess();
}
if (einsum_eqn == EinsumEquation::FourDBatchMatMul) {
// Case "BFNH,BTNH->BNFT"
// Transpose LHS
lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter);
// Transpose RHS
rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter);
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {bmm_op.getResult()});
return matchSuccess();
}
return matchFailure();
}
// Transform Einsum to other TF Ops for the supported variants.
struct TransformEinsumPass : public FunctionPass<TransformEinsumPass> {
void runOnFunction() override;
};
void TransformEinsumPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFEinsumOp>(&getContext());
applyPatternsGreedily(func, patterns);
}
static PassRegistration<TransformEinsumPass> pass(
"tf-einsum", "Transform Einsum to other TF Ops for the supported variants");
} // namespace TF
} // namespace mlir

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for certain Einsum Ops and replaces them
// with other equivalent TF Ops.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_
#include <cstdint>
#include <initializer_list>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TF {
// TF.Einsum provides fully general tensor contractions. For a few select
// cases, we can convert this op to other TF Ops, which in later passes
// properly convert to TF Lite ops.
struct ConvertTFEinsumOp : public OpRewritePattern<TF::EinsumOp> {
public:
explicit ConvertTFEinsumOp(MLIRContext* context)
: OpRewritePattern<TF::EinsumOp>(context) {}
PatternMatchResult matchAndRewrite(TF::EinsumOp op,
PatternRewriter& rewriter) const override;
};
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_