Convert additional Einsum formulas

PiperOrigin-RevId: 303742280
Change-Id: I8965a5d534ab730058b65cad94a6cf634f466da4
This commit is contained in:
T.J. Alumbaugh 2020-03-30 08:09:14 -07:00 committed by TensorFlower Gardener
parent 82ff6702ab
commit b43f1e0072
2 changed files with 111 additions and 3 deletions

View File

@ -7,6 +7,37 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
}
// Equation "ijk,km->ijm": contraction over k with a 2-D RHS. Lowers directly
// to a single tf.BatchMatMulV2 — the op's broadcasting handles the rank-2 RHS
// against the batched rank-3 LHS, so no transpose or reshape is emitted.
func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: einsum_broadcast
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32>
}
// Equation "lbh,bl->bh": per the CHECK lines, this lowers to
//   transpose LHS (l,b,h) -> (b,h,l),
//   reshape RHS (b,l) -> (b,1,l) so it broadcasts over h,
//   elementwise tf.Mul, then tf.Sum over the trailing l axis (keep_dims=false).
func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32>
return %0 : tensor<5x7xf32>
// CHECK-LABEL: einsum_reducesum
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64>
// CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32>
// CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32>
// CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32>
}
// Equation "lbh,bkl->bkh": per the CHECK lines, this lowers to
//   transpose LHS (l,b,h) -> (b,h,l),
//   transpose RHS (b,k,l) -> (b,l,k),
//   tf.BatchMatMulV2 producing (b,h,k),
//   final transpose (b,h,k) -> (b,k,h) to match the output subscripts.
func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32>
return %0 : tensor<5x3x7xf32>
// CHECK-LABEL: einsum_transpose_matmul
// CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[cst_0:.*]] = constant dense<[0, 2, 1]> : tensor<3xi32>
// CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32>
// CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_0]]) : (tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<5x2x3xf32>
// CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x3xf32>) -> tensor<5x7x3xf32>
// CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_0]]) : (tensor<5x7x3xf32>, tensor<3xi32>) -> tensor<5x3x7xf32>
}
func @einsum_4D(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32> {
%0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bfnh,btnh->bnft"}: (tensor<2x5x7x3xf32>, tensor<2x4x7x3xf32>) -> tensor<2x7x5x4xf32>
return %0 : tensor<2x7x5x4xf32>

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include <algorithm>
#include <climits>
#include <cstdint>
@ -49,6 +50,9 @@ enum EinsumEquation {
FourDMatrixDotProd,
ThreeDReshapeTail,
FourDBatchMatMul,
BroadcastMatMul,
ReduceSum,
TransposeMatMul,
UnsupportedEquation
};
@ -121,6 +125,18 @@ EinsumEquation parseEquation(const std::vector<EquationToken>& eqn) {
if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) {
return EinsumEquation::ThreeDReshapeTail;
}
// BFH,HO->BFO
if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) {
return EinsumEquation::BroadcastMatMul;
}
// LBH,BL->BH
if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) {
return EinsumEquation::ReduceSum;
}
// LBH,BKL->BKH
if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) {
return EinsumEquation::TransposeMatMul;
}
return EinsumEquation::UnsupportedEquation;
}
@ -151,6 +167,28 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
perm_op);
}
// Creates a tf.Sum that reduces `value` along `redux_axes`.
//
// The axes are materialized as a 1-D i32 constant operand, and the result
// type is derived from `value`'s ranked shape with the reduced dimensions
// removed (the tf.Sum is emitted without keep_dims, matching the lowering's
// expected `keep_dims = false` form — reduced dims are dropped, not kept
// as size-1).
//
// `value` must have a RankedTensorType; `redux_axes` holds non-negative
// dimension indices into that shape.
TF::SumOp createSumOp(Value value, Location loc,
                      llvm::ArrayRef<int32_t> redux_axes,
                      PatternRewriter* rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  // Reduction indices constant: shape [#axes], element type i32.
  auto redux_type = RankedTensorType::get(
      {static_cast<int32_t>(redux_axes.size())}, rewriter->getIntegerType(32));
  auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes);
  auto redux_op = rewriter->create<ConstantOp>(loc, redux_type, redux_attr);
  // Result shape: the input shape with every reduced dimension removed.
  // int64_t loop index avoids the signed/unsigned comparison against
  // shape.size() that a plain `int` would incur.
  std::vector<int64_t> sum_shape;
  sum_shape.reserve(shape.size() - redux_axes.size());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    if (std::find(redux_axes.begin(), redux_axes.end(), i) ==
        redux_axes.end()) {
      sum_shape.push_back(shape[i]);
    }
  }
  auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType());
  return rewriter->create<TF::SumOp>(loc, sum_type, value, redux_op);
}
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
Type element_type, Location loc,
PatternRewriter* rewriter) {
@ -173,7 +211,6 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
Location loc = op.getLoc();
if (!lhs.getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
return failure();
@ -193,10 +230,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
return failure();
}
// Currently support use cases of LHS, RHS dims = 3 or 4
// Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4}
const int dims_lhs = lhs_shape.size();
const int dims_rhs = rhs_shape.size();
if (dims_rhs < 3 || dims_rhs > 4 || dims_lhs < 3 || dims_lhs > 4) {
if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_lhs > 4) {
return failure();
}
@ -209,6 +246,46 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite(
rewriter.replaceOp(op, bmm_op.getResult());
return success();
}
if (einsum_eqn == EinsumEquation::BroadcastMatMul) {
// Case "BFH,HO->BFO"
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{output_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, bmm_op.getResult());
return success();
}
if (einsum_eqn == EinsumEquation::ReduceSum) {
// Case "LBH,BL->BH"
// Transpose LHS
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
// Reshape RHS
auto rhs_element_type = rhs_type.getElementType();
const int rhs_dim0 = rhs_shape[0];
const int rhs_dim1 = rhs_shape[1];
auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1},
rhs_element_type, loc, &rewriter);
auto mul_op = rewriter.create<TF::MulOp>(loc, lhs, reshaped_rhs);
auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter);
rewriter.replaceOp(op, {sum_op.getResult()});
return success();
}
if (einsum_eqn == EinsumEquation::TransposeMatMul) {
// Case "LBH,BKL->BKH"
// Transpose LHS
lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter);
// Transpose RHS
rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter);
std::vector<int64_t> bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]};
auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType());
auto bmm_op = rewriter.create<TF::BatchMatMulV2Op>(
loc, ArrayRef<Type>{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter);
rewriter.replaceOp(op, {trans_bmm.getResult()});
return success();
}
if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) {
// Case "BFD,DNH->BFNH"
auto lhs_type = lhs.getType().cast<RankedTensorType>();