Convert additional Einsum formulas
PiperOrigin-RevId: 303742280 Change-Id: I8965a5d534ab730058b65cad94a6cf634f466da4
This commit is contained in:
parent
82ff6702ab
commit
b43f1e0072
@ -7,6 +7,37 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor
|
||||
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
// Checks lowering of a 3D x 2D einsum ("ijk,km->ijm") to a single
// tf.BatchMatMulV2: the rank-2 RHS is broadcast across the batch
// dimension of the rank-3 LHS (no transpose/reshape needed).
func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
  return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: einsum_broadcast
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
}
|
||||
|
||||
// Checks lowering of einsum "lbh,bl->bh" to transpose + reshape +
// elementwise multiply + sum: the LHS is transposed to (b, h, l) order,
// the RHS is reshaped to rank 3 (5x1x2) so it broadcasts against the LHS,
// and the shared `l` axis (now axis 2) is reduced with tf.Sum.
func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32>
  return %0 : tensor<5x7xf32>
// CHECK-LABEL: einsum_reducesum
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64>
// CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32>
// CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32>
// CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32>
}
|
||||
// Checks lowering of einsum "lbh,bkl->bkh" to transposes around a batch
// matmul: both operands are transposed into batch-major, contraction-last
// layout, contracted with tf.BatchMatMulV2, and the (b, h, k) result is
// transposed back to the requested (b, k, h) order.
func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> {
  %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32>
  return %0 : tensor<5x3x7xf32>
// CHECK-LABEL: einsum_transpose_matmul
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_0:.*]] = constant dense<[0, 2, 1]> : tensor<3xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32>
// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
// CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32>
}
|
||||
|
||||
func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
|
||||
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
|
||||
return %0 : tensor<2x7x5x4xf32>
|
||||
|
||||
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
@ -49,6 +50,9 @@ enum EinsumEquation {
|
||||
FourDMatrixDotProd,
|
||||
ThreeDReshapeTail,
|
||||
FourDBatchMatMul,
|
||||
BroadcastMatMul,
|
||||
ReduceSum,
|
||||
TransposeMatMul,
|
||||
UnsupportedEquation
|
||||
};
|
||||
|
||||
@ -121,6 +125,18 @@ EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
|
||||
if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
|
||||
return EinsumEquation::ThreeDReshapeTail;
|
||||
}
|
||||
// BFH,HO->BFO
|
||||
if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) {
|
||||
return EinsumEquation::BroadcastMatMul;
|
||||
}
|
||||
// LBH,BL->BH
|
||||
if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) {
|
||||
return EinsumEquation::ReduceSum;
|
||||
}
|
||||
// LBH,BKL->BKH
|
||||
if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) {
|
||||
return EinsumEquation::TransposeMatMul;
|
||||
}
|
||||
return EinsumEquation::UnsupportedEquation;
|
||||
}
|
||||
|
||||
@ -151,6 +167,28 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
|
||||
perm_op);
|
||||
}
|
||||
|
||||
// Creates a tf.Sum op that reduces `value` along `redux_axes`.
//
// The reduction axes are materialized as a constant rank-1 i32 tensor, and
// the result type is derived by dropping the reduced dimensions from the
// input shape (keep_dims is not set, so it takes the op's default).
//
// `value` must have a RankedTensorType; `redux_axes` holds non-negative
// axis indices into that shape.
TF::SumOp createSumOp(Value value, Location loc,
                      llvm::ArrayRef<int32_t> redux_axes,
                      PatternRewriter* rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  // Constant holding the axes to reduce, as a tensor<Nxi32>.
  auto redux_type = RankedTensorType::get(
      {static_cast<int32_t>(redux_axes.size())}, rewriter->getIntegerType(32));
  auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes);
  auto redux_op = rewriter->create<ConstantOp>(loc, redux_type, redux_attr);
  // Result shape: the input dims whose index is not in `redux_axes`.
  // Use size_t for the index to avoid a signed/unsigned comparison against
  // shape.size(), and compare as int32_t to match the axis element type.
  std::vector<int64_t> sum_shape;
  sum_shape.reserve(shape.size() - redux_axes.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    const int32_t axis = static_cast<int32_t>(i);
    if (std::find(redux_axes.begin(), redux_axes.end(), axis) ==
        redux_axes.end()) {
      sum_shape.push_back(shape[i]);
    }
  }
  auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType());
  return rewriter->create<TF::SumOp>(loc, sum_type, value, redux_op);
}
|
||||
|
||||
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
|
||||
Type element_type, Location loc,
|
||||
PatternRewriter* rewriter) {
|
||||
@ -173,7 +211,6 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
|
||||
Value lhs = op.getOperand(0);
|
||||
Value rhs = op.getOperand(1);
|
||||
Location loc = op.getLoc();
|
||||
|
||||
if (!lhs.getType().isa<RankedTensorType>()) {
|
||||
// LHS must be a ranked tensor type
|
||||
return failure();
|
||||
@ -193,10 +230,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Currently support use cases of LHS, RHS dims = 3 or 4
|
||||
// Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4}
|
||||
const int dims_lhs = lhs_shape.size();
|
||||
const int dims_rhs = rhs_shape.size();
|
||||
if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
|
||||
if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_rhs > 4) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -209,6 +246,46 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
|
||||
rewriter.replaceOp(op, bmm_op.getResult());
|
||||
return success();
|
||||
}
|
||||
if (einsum_eqn == EinsumEquation::BroadcastMatMul) {
|
||||
// Case "BFH,HO->BFO"
|
||||
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
|
||||
loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOp(op, bmm_op.getResult());
|
||||
return success();
|
||||
}
|
||||
if (einsum_eqn == EinsumEquation::ReduceSum) {
|
||||
// Case "LBH,BL->BH"
|
||||
// Transpose LHS
|
||||
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
|
||||
// Reshape RHS
|
||||
auto rhs_element_type = rhs_type.getElementType();
|
||||
const int rhs_dim0 = rhs_shape[0];
|
||||
const int rhs_dim1 = rhs_shape[1];
|
||||
auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1},
|
||||
rhs_element_type, loc, &rewriter);
|
||||
auto mul_op = rewriter.create<TF::MulOp>(loc, lhs, reshaped_rhs);
|
||||
|
||||
auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter);
|
||||
rewriter.replaceOp(op, {sum_op.getResult()});
|
||||
return success();
|
||||
}
|
||||
if (einsum_eqn == EinsumEquation::TransposeMatMul) {
|
||||
// Case "LBH,BKL->BKH"
|
||||
// Transpose LHS
|
||||
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
|
||||
// Transpose RHS
|
||||
rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter);
|
||||
std::vector<int64_t> bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]};
|
||||
auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
|
||||
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
|
||||
loc, ArrayRef<Type>{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
|
||||
auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter);
|
||||
rewriter.replaceOp(op, {trans_bmm.getResult()});
|
||||
return success();
|
||||
}
|
||||
if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
|
||||
// Case "BFD,DNH->BFNH"
|
||||
auto lhs_type = lhs.getType().cast<RankedTensorType>();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user