Support all tf.MatMul in the TFL legalization pass

PiperOrigin-RevId: 328911250
Change-Id: I70bcad27c107c069ca7c1daac96d33cc962b8be1
A. Unique TensorFlower, 2020-08-28 03:28:47 -07:00, committed by TensorFlower Gardener
parent 92cf015f8a
commit 0f17988fd9
2 changed files with 70 additions and 15 deletions
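In short: before this change, ConvertTFMatMulOp bailed out unless transpose_a was false and transpose_b was true, because that is the one operand layout tfl.fully_connected accepts directly (weights stored as [output_depth, input_depth], computing output = input * filter^T). The patch instead inserts a tfl.transpose in front of any 2-D operand that is not already in that layout, so all four transpose_a/transpose_b combinations now legalize. A minimal sketch of the layout rule, with hypothetical helper names that are not part of the patch:

// tfl.fully_connected computes output = input . filter^T, so the lhs must be
// laid out [batch, depth] and the rhs [output_depth, depth]. An operand needs
// an explicit tfl.transpose exactly when tf.MatMul's flags leave it in the
// opposite layout.
struct OperandTransposes {
  bool lhs;  // insert tfl.transpose on the left operand?
  bool rhs;  // insert tfl.transpose on the right operand?
};

OperandTransposes GetRequiredTransposes(bool transpose_a, bool transpose_b) {
  // transpose_a flips the lhs out of [batch, depth]; a cleared transpose_b
  // leaves the rhs as [depth, output_depth] rather than [output_depth, depth].
  return OperandTransposes{/*lhs=*/transpose_a, /*rhs=*/!transpose_b};
}

This mirrors the new matchAndRewrite below, which transposes lhs when transpose_a() is set and rhs when transpose_b() is not.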


@@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32
// CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
}
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul
// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed_a
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_2:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed
// CHECK-LABEL: matmul_transposed_b
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed_ab
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_1:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>


@@ -66,7 +66,6 @@ namespace TFL {
// The actual LegalizeTF Pass.
namespace {
using xla::Status;
using xla::StatusOr;
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
@@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
return success();
}
// The following is effectively:
// def : Pat<
// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
// ConstBoolAttrTrue:$transpose_b),
// (TFL_FullyConnectedOp:$__0 $a, $b,
// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
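// Note that the rewrite below generalizes this pattern: operands are
// transposed as needed, so every transpose_a/transpose_b combination, not
// just (false, true), lowers to tfl.fully_connected.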
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_matmul_op = cast<TF::MatMulOp>(op);
if (tf_matmul_op.transpose_a()) return failure();
if (!tf_matmul_op.transpose_b()) return failure();
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
RankedTensorType type =
input.getType().dyn_cast_or_null<RankedTensorType>();
if (!type || type.getRank() != 2) return {failure(), nullptr};
auto permute_attr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
auto permute = rewriter.create<ConstantOp>(
op->getLoc(), permute_attr.getType(), permute_attr);
llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
type.getShape()[0]};
auto output = rewriter.create<TFL::TransposeOp>(
op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
input, permute);
return {success(), output};
};
// TODO(jpienaar): Remove once handled via dialect conversion.
if (tf_matmul_op.transpose_a()) {
LogicalResult result = success();
std::tie(result, lhs) = transpose(lhs);
if (failed(result)) return failure();
}
if (!tf_matmul_op.transpose_b()) {
LogicalResult result = success();
std::tie(result, rhs) = transpose(rhs);
if (failed(result)) return failure();
}
Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
auto no_input = rewriter.create<ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
auto fc_op = rewriter.create<FullyConnectedOp>(
op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {fc_op.getResult(0)});
return success();
}
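
This pattern is plain C++ rather than the TableGen DRR quoted in the comment above, since the rewrite is conditional on operand rank and must materialize intermediate tfl.transpose ops. For context, a minimal sketch of how such a pattern could be driven on its own; the pass name here is hypothetical, the real legalization pass registers many more patterns, and exact MLIR driver signatures varied across revisions in this era:

// Hypothetical standalone pass that applies only ConvertTFMatMulOp via the
// generic greedy rewrite driver (MLIR APIs as of mid-2020).
struct LegalizeTFMatMulOnlyPass
    : public mlir::PassWrapper<LegalizeTFMatMulOnlyPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    patterns.insert<ConvertTFMatMulOp>(&getContext());
    // Repeatedly applies the pattern set (plus foldings) until fixpoint.
    mlir::applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};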