diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 5b45862a2b3..0328761becc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -165,6 +165,18 @@ func @const() -> tensor<2xi32> { return %0: tensor<2xi32> } +//===----------------------------------------------------------------------===// +// Matmul op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: matmul_notranspose +func @matmul_notranspose(%arg0: tensor<5x7xf32>, %arg1: tensor<7x11xf32>) -> tensor<5x11xf32> { + // CHECK: "xla_hlo.dot"(%arg0, %arg1) + %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> + + return %0 : tensor<5x11xf32> +} + //===----------------------------------------------------------------------===// // Relu op legalizations. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 1730e5374a4..d67f7b0c5fd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -20,6 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" +def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; def NullElementsAttr : NativeCodeCall<"ElementsAttr()">; //===----------------------------------------------------------------------===// @@ -92,6 +93,14 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>; +//===----------------------------------------------------------------------===// +// Matmul op patterns. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Lower matmul ops with transpose attributes. +def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse, ConstBoolAttrFalse), + (HLO_DotOp $a, $b, (NullArrayAttr))>; + //===----------------------------------------------------------------------===// // Nullary op patterns. //===----------------------------------------------------------------------===//