Support all tf.MatMul in the TFL legalization pass

PiperOrigin-RevId: 328911250
Change-Id: I70bcad27c107c069ca7c1daac96d33cc962b8be1

commit 0f17988fd9 (parent 92cf015f8a)
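
In short: the pass previously legalized tf.MatMul to tfl.fully_connected only for transpose_a = false, transpose_b = true, and now normalizes the other three combinations by inserting tfl.transpose ops. A minimal sketch of the new behavior for the both-transposed case (function and value names here are illustrative; the shapes mirror the tests in this diff):

func @matmul_transposed_both(%a: tensor<37x40xf32>, %b: tensor<40x37xf32>) -> tensor<40x40xf32> {
  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} :
       (tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
  return %0 : tensor<40x40xf32>
}

After legalization only the LHS gets an explicit transpose, because tfl.fully_connected already expects its weights (the RHS) in transposed form:

%cst = constant dense<[1, 0]> : tensor<2xi32>
%lhs = "tfl.transpose"(%a, %cst) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
%none = constant unit
%fc = "tfl.fully_connected"(%lhs, %b, %none) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>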
@@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32
 // CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
 }
 
-func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
+(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul
+// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_0:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
+(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_a
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_2:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
   %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
 (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
   return %0 : tensor<40x40xf32>
-// CHECK-LABEL: matmul_transposed
+// CHECK-LABEL: matmul_transposed_b
 // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
 }
+
+func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
+(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_ab
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
 
 func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
   %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
   %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
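For reference, this test file is FileCheck-driven; assuming it keeps its usual RUN header (not part of this hunk), the new cases would be exercised with something like:

// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s

where tfl-legalize-tf is the command-line name under which the TFL legalization pass is registered.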
@@ -66,7 +66,6 @@ namespace TFL {
 // The actual LegalizeTF Pass.
 namespace {
 
-using xla::Status;
 using xla::StatusOr;
 
 constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
@@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
   return success();
 }
 
-// The following is effectively:
-// def : Pat<
-//   (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
-//                ConstBoolAttrTrue:$transpose_b),
-//   (TFL_FullyConnectedOp:$__0 $a, $b,
-//     NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_matmul_op = cast<TF::MatMulOp>(op);
-  if (tf_matmul_op.transpose_a()) return failure();
-  if (!tf_matmul_op.transpose_b()) return failure();
+  auto lhs = op->getOperand(0);
+  auto rhs = op->getOperand(1);
+  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
+    RankedTensorType type =
+        input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!type || type.getRank() != 2) return {failure(), nullptr};
+
+    auto permute_attr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
+    auto permute = rewriter.create<ConstantOp>(
+        op->getLoc(), permute_attr.getType(), permute_attr);
+    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
+                                            type.getShape()[0]};
+    auto output = rewriter.create<TFL::TransposeOp>(
+        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
+        input, permute);
+    return {success(), output};
+  };
+
+  // TODO(jpienaar): Remove once handled via dialect conversion.
+  if (tf_matmul_op.transpose_a()) {
+    LogicalResult result = success();
+    std::tie(result, lhs) = transpose(lhs);
+    if (failed(result)) return failure();
+  }
+  if (!tf_matmul_op.transpose_b()) {
+    LogicalResult result = success();
+    std::tie(result, rhs) = transpose(rhs);
+    if (failed(result)) return failure();
+  }
 
   Type output_type = tf_matmul_op.getResult().getType();
-  // TODO(jpienaar): Follow up post shuffle discussion.
   auto no_input = rewriter.create<ConstantOp>(
       op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
   auto fc_op = rewriter.create<FullyConnectedOp>(
-      op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
-      op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
-      rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
+      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
+      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
+      rewriter.getBoolAttr(false));
   rewriter.replaceOp(op, {fc_op.getResult(0)});
   return success();
 }
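One detail worth noting in the hunk above: the transpose helper only fires for ranked 2-D operands (if (!type || type.getRank() != 2) return {failure(), nullptr};), so a tf.MatMul whose to-be-transposed operand is unranked still fails to match and is left for later passes. A hypothetical sketch (names illustrative):

func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<37x40xf32>) -> tensor<*xf32> {
  // transpose_a = true would require transposing %a, but %a is unranked,
  // so matchAndRewrite returns failure() and the op is not rewritten.
  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = false} :
       (tensor<*xf32>, tensor<37x40xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}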