Initial support for Einsum op in TF Lite
PiperOrigin-RevId: 301159606
Change-Id: Ie8d9ef631ea64061423f95c06050b6329c844847

parent 32f5b7dd76   commit 6072451f8b
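
The change adds a standalone "tf-einsum" rewrite pass (einsum.h/einsum.cc below) and registers its ConvertTFEinsumOp pattern in the TF Lite converter's PrepareTF pass. A handful of fixed tf.Einsum contraction patterns are lowered to tf.BatchMatMulV2, with tf.Reshape/tf.Transpose inserted where needed. A minimal before/after sketch of the simplest case, mirroring the einsum_basic test added below:

  // Input, run through `tf-opt -tf-einsum`:
  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ijk,ikm->ijm"}
       : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
  // After the rewrite:
  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false}
       : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>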
tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc

@@ -56,6 +56,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
 
 #define DEBUG_TYPE "tf-tfl-legalization"
@@ -655,8 +656,8 @@ void PrepareTFPass::runOnFunction() {
     patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                     TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
   }
-  patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative,
-                  ConvertTFStridedSlice>(ctx);
+  patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
+                  ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
   applyPatternsGreedily(func, patterns);
 }
tensorflow/compiler/mlir/tensorflow/BUILD

@@ -209,6 +209,7 @@ cc_library(
         "ir/tf_traits.h",
         "ir/tf_verifiers.h",
         "transforms/bridge.h",
+        "transforms/einsum.h",
         "transforms/passes.h",
         "transforms/unroll_batch_matmul.h",
         "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.h",
@@ -310,6 +311,7 @@ cc_library(
         "transforms/cluster_outlining.cc",
         "transforms/collection_ops_util.cc",
         "transforms/decompose_resource_ops_pass.cc",
+        "transforms/einsum.cc",
         "transforms/executor_island_coarsening.cc",
         "transforms/executor_tpuv1_inline_tpu_island.cc",
         "transforms/executor_tpuv1_island_coarsening.cc",
@@ -357,6 +359,7 @@ cc_library(
         "transforms/batchmatmul_to_einsum.h",
         "transforms/bridge.h",
         "transforms/collection_ops_util.h",
+        "transforms/einsum.h",
         "transforms/passes.h",
         "transforms/shape_inference.h",
     ],
tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir (new file, 57 lines)

@@ -0,0 +1,57 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-einsum %s | FileCheck %s

func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
  return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: einsum_basic
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
}

func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
  return %0 : tensor<2x7x5x4xf32>
// CHECK-LABEL: einsum_4D
// CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32>
// CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32>
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x4x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x4xf32>
// CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x4xf32>) -> tensor<2x7x5x4xf32>
}

func @einsum_matrixdotprod(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<7x3x4xf32>) -> tensor<2x5x4xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnd,ndh->bfh"}: (tensor<2x5x7x3xf32>, tensor<7x3x4xf32>) -> tensor<2x5x4xf32>
  return %0 : tensor<2x5x4xf32>
// CHECK-LABEL: einsum_matrixdotprod
// CHECK: %[[cst:.*]] = constant dense<[2, 5, 21]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[21, 4]> : tensor<2xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<3xi64>) -> tensor<2x5x21xf32>
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<7x3x4xf32>, tensor<2xi64>) -> tensor<21x4xf32>
// CHECK: "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x21xf32>, tensor<21x4xf32>) -> tensor<2x5x4xf32>
}

func @einsum_reshapetail(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfd,dnh->bfnh"}: (tensor<3x4x5xf32>, tensor<5x6x2xf32>) -> tensor<3x4x6x2xf32>
  return %0 : tensor<3x4x6x2xf32>
// CHECK-LABEL: einsum_reshapetail
// CHECK: %[[cst:.*]] = constant dense<[5, 12]> : tensor<2xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[3, 4, 6, 2]> : tensor<4xi64>
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<5x6x2xf32>, tensor<2xi64>) -> tensor<5x12xf32>
// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x12xf32>) -> tensor<3x4x12xf32>
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<3x4x12xf32>, tensor<4xi64>) -> tensor<3x4x6x2xf32>
// CHECK: return %[[v2]] : tensor<3x4x6x2xf32>
}

func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
// CHECK-LABEL: einsum_no_match
// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
// CHECK: return %[[v0]]
}
func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
// CHECK-LABEL: einsum_illegal_no_match
// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>
// CHECK: return %[[v0]]
}
tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc (new file, 296 lines)

@@ -0,0 +1,296 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"

#include <climits>
#include <cstdint>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Regex.h"
#include "mlir/Analysis/LoopAnalysis.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/OpImplementation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Support/Functional.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {

// All supported Einsum equations.
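// Each enumerator corresponds to one fixed contraction pattern (using the
// labels from tests/einsum.mlir):
//   BatchMatMul:        "ijk,ikm->ijm"    -> tf.BatchMatMulV2
//   ThreeDReshapeTail:  "bfd,dnh->bfnh"   -> reshape RHS + BatchMatMulV2 + reshape
//   FourDMatrixDotProd: "bfnd,ndh->bfh"   -> reshape LHS/RHS + BatchMatMulV2
//   FourDBatchMatMul:   "bfnh,btnh->bnft" -> transpose LHS/RHS + BatchMatMulV2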
enum EinsumEquation {
  BatchMatMul,
  FourDMatrixDotProd,
  ThreeDReshapeTail,
  FourDBatchMatMul,
  UnsupportedEquation
};

// Tokens for parsing the given equation string.
enum EquationToken {
  A,
  B,
  C,
  D,
  E,
  COMMA,
  ARROW,
};
inline constexpr int kNumSupportedEquationVariables = 5;  // A - E for now.

bool tokenizeEquation(const llvm::StringRef& equation,
                      std::vector<EquationToken>* tokens) {
  std::map<char, EquationToken> label_axis_mapping;
  int index = 0;
  int variable_count = 0;
  llvm::Regex r("[[:alpha:]]");
  while (index < equation.size()) {
    if (r.match(equation.substr(index, 1))) {
      const char ltr = equation[index];
      auto itr = label_axis_mapping.find(ltr);
      if (itr == label_axis_mapping.end() &&
          variable_count < kNumSupportedEquationVariables) {
        label_axis_mapping[ltr] = EquationToken(variable_count);
        tokens->push_back(EquationToken(variable_count));
        variable_count++;
      } else if (itr != label_axis_mapping.end()) {
        tokens->push_back(itr->second);
      } else {
        // Ran out of equation variables.
        return false;
      }
    } else if (equation.substr(index, 1).contains(",")) {
      tokens->push_back(COMMA);
    } else if ((index < (equation.size() - 1)) &&
               (equation.substr(index, 2).contains("->"))) {
      tokens->push_back(ARROW);
      index++;
    } else {
      // Unallowed character encountered.
      return false;
    }
    index++;
  }
  return true;
}

EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
  auto is_equal = [](const std::vector<EquationToken>& eqn1,
                     const std::initializer_list<EquationToken>& eqn2) {
    return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end());
  };
  // IJK,IKM->IJM
  if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) {
    return EinsumEquation::BatchMatMul;
  }
  // BFND,NDH->BFH
  if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) {
    return EinsumEquation::FourDMatrixDotProd;
  }
  // BFNH,BTNH->BNFT
  if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) {
    return EinsumEquation::FourDBatchMatMul;
  }
  // BFD,DNH->BFNH
  if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
    return EinsumEquation::ThreeDReshapeTail;
  }
  return EinsumEquation::UnsupportedEquation;
}

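// Maps an equation string to one of the supported EinsumEquation kinds.
// For example, "bfd,dnh->bfnh" tokenizes to
//   {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E}
// (b->A, f->B, d->C, n->D, h->E) and parses to
// EinsumEquation::ThreeDReshapeTail.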
EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) {
  std::vector<EquationToken> tokens;
  if (tokenizeEquation(equation, &tokens)) {
    return parseEquation(tokens);
  }
  return EinsumEquation::UnsupportedEquation;
}

TF::TransposeOp createTransposeOp(Value value, Location loc,
                                  llvm::ArrayRef<int32_t> permutation,
                                  PatternRewriter* rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  auto perm_type = RankedTensorType::get(
      {static_cast<int32_t>(permutation.size())}, rewriter->getIntegerType(32));
  auto perm_attr = DenseElementsAttr::get(perm_type, permutation);
  auto perm_op = rewriter->create<ConstantOp>(loc, perm_type, perm_attr);
  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  for (int i = 0; i < shape.size(); ++i) {
    transposed_shape[i] = shape[permutation[i]];
  }
  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter->create<TF::TransposeOp>(loc, transposed_type, value,
                                           perm_op);
}

TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                              Type element_type, Location loc,
                              PatternRewriter* rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter->getIntegerType(64));
  Type resultType = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter->create<ConstantOp>(loc, shape_spec_type, constant_attr);
  return rewriter->create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                         /*shape=*/shape_tensor);
}

}  // namespace

PatternMatchResult ConvertTFEinsumOp::matchAndRewrite(
    TF::EinsumOp op, PatternRewriter& rewriter) const {
  Type output_type = op.getResult().getType();
  Value lhs = op.getOperand(0);
  Value rhs = op.getOperand(1);
  Location loc = op.getLoc();

  if (!lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type
    return matchFailure();
  }
  if (!rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type
    return matchFailure();
  }

  auto lhs_type = lhs.getType().cast<RankedTensorType>();
  auto rhs_type = rhs.getType().cast<RankedTensorType>();
  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  // Currently only support static shapes.
  if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) {
    return matchFailure();
  }

  // Currently support use cases of LHS, RHS dims = 3 or 4
  const int dims_lhs = lhs_shape.size();
  const int dims_rhs = rhs_shape.size();
  if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
    return matchFailure();
  }

  EinsumEquation einsum_eqn = tokenizeAndParse(op.equation());
  if (einsum_eqn == EinsumEquation::BatchMatMul) {
    // Case "IJK,IKM->IJM"
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));
    rewriter.replaceOp(op, bmm_op.getResult());
    return matchSuccess();
  }
  if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
    // Case "BFD,DNH->BFNH"
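    // For example (see einsum_reshapetail in tests/einsum.mlir):
    // lhs 3x4x5, rhs 5x6x2 -> rhs is reshaped to 5x12, BatchMatMulV2 gives
    // 3x4x12, and a final reshape restores 3x4x6x2.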
    auto lhs_type = lhs.getType().cast<RankedTensorType>();
    auto lhs_shape = lhs_type.getShape();
    const int lhs_dim0 = lhs_shape[0];
    const int lhs_dim1 = lhs_shape[1];
    // Reshape RHS
    auto rhs_type = rhs.getType().cast<RankedTensorType>();
    auto rhs_shape = rhs_type.getShape();
    auto rhs_element_type = rhs_type.getElementType();
    const int rhs_dim0 = rhs_shape[0];
    const int rhs_dim1 = rhs_shape[1];
    const int rhs_dim2 = rhs_shape[2];
    auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2},
                                        rhs_element_type, loc, &rewriter);

    std::vector<int64_t> bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2};
    auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{bmm_type}, lhs, reshaped_rhs,
        rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
    auto bmm_element_type = bmm_type.getElementType();
    auto final_reshape =
        createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2},
                        bmm_element_type, loc, &rewriter);
    rewriter.replaceOp(op, {final_reshape.getResult()});
    return matchSuccess();
  }
  if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) {
    // Case "BFND,NDH->BFH"
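    // For example (see einsum_matrixdotprod in tests/einsum.mlir):
    // lhs 2x5x7x3 is reshaped to 2x5x21, rhs 7x3x4 to 21x4, and a single
    // BatchMatMulV2 produces the 2x5x4 result.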
    // Reshape LHS
    auto lhs_element_type = lhs_type.getElementType();
    const int lhs_dim0 = lhs_shape[0];
    const int lhs_dim1 = lhs_shape[1];
    const int lhs_dim2 = lhs_shape[2];
    const int lhs_dim3 = lhs_shape[3];
    auto reshaped_lhs =
        createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3},
                        lhs_element_type, loc, &rewriter);
    // Reshape RHS
    auto rhs_element_type = rhs_type.getElementType();
    const int rhs_dim0 = rhs_shape[0];
    const int rhs_dim1 = rhs_shape[1];
    const int rhs_dim2 = rhs_shape[2];
    auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2},
                                        rhs_element_type, loc, &rewriter);
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{output_type}, reshaped_lhs, reshaped_rhs,
        rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
    rewriter.replaceOp(op, {bmm_op.getResult()});
    return matchSuccess();
  }
  if (einsum_eqn == EinsumEquation::FourDBatchMatMul) {
    // Case "BFNH,BTNH->BNFT"
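    // For example (see einsum_4D in tests/einsum.mlir): lhs 2x5x7x3 is
    // transposed with permutation {0, 2, 1, 3} to 2x7x5x3, rhs 2x4x7x3 with
    // {0, 2, 3, 1} to 2x7x3x4, and BatchMatMulV2 yields 2x7x5x4.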
    // Transpose LHS
    lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter);
    // Transpose RHS
    rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter);
    auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
        loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));
    rewriter.replaceOp(op, {bmm_op.getResult()});
    return matchSuccess();
  }
  return matchFailure();
}

// Transform Einsum to other TF Ops for the supported variants.
struct TransformEinsumPass : public FunctionPass<TransformEinsumPass> {
  void runOnFunction() override;
};

void TransformEinsumPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFEinsumOp>(&getContext());
  applyPatternsGreedily(func, patterns);
}

static PassRegistration<TransformEinsumPass> pass(
    "tf-einsum", "Transform Einsum to other TF Ops for the supported variants");

}  // namespace TF
}  // namespace mlir
tensorflow/compiler/mlir/tensorflow/transforms/einsum.h (new file, 55 lines)

@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for certain Einsum Ops and replaces them
// with other equivalent TF Ops.

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_

#include <cstdint>
#include <initializer_list>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

// TF.Einsum provides fully general tensor contractions. For a few select
// cases, we can convert this op to other TF Ops, which in later passes
// properly convert to TF Lite ops.
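// For example, "ijk,ikm->ijm" becomes a single tf.BatchMatMulV2, and
// "bfd,dnh->bfnh" becomes tf.Reshape + tf.BatchMatMulV2 + tf.Reshape
// (see tests/einsum.mlir for the full set of supported equations).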
struct ConvertTFEinsumOp : public OpRewritePattern<TF::EinsumOp> {
 public:
  explicit ConvertTFEinsumOp(MLIRContext* context)
      : OpRewritePattern<TF::EinsumOp>(context) {}

  PatternMatchResult matchAndRewrite(TF::EinsumOp op,
                                     PatternRewriter& rewriter) const override;
};

}  // namespace TF
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_