Add XLA lowering pattern for TensorFlow MatMul op

This only handles the simplest case of no transpose. MNIST model only uses this case.

PiperOrigin-RevId: 266990244
This commit is contained in:
Smit Hinsu 2019-09-03 12:40:16 -07:00 committed by TensorFlower Gardener
parent 99f8d44812
commit 17099181cd
2 changed files with 21 additions and 0 deletions

View File

@ -165,6 +165,18 @@ func @const() -> tensor<2xi32> {
return %0: tensor<2xi32> return %0: tensor<2xi32>
} }
//===----------------------------------------------------------------------===//
// Matmul op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: matmul_notranspose
func @matmul_notranspose(%arg0: tensor<5x7xf32>, %arg1: tensor<7x11xf32>) -> tensor<5x11xf32> {
// CHECK: "xla_hlo.dot"(%arg0, %arg1)
%0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>
return %0 : tensor<5x11xf32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Relu op legalizations. // Relu op legalizations.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -20,6 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def NullElementsAttr : NativeCodeCall<"ElementsAttr()">; def NullElementsAttr : NativeCodeCall<"ElementsAttr()">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -92,6 +93,14 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>; def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>;
//===----------------------------------------------------------------------===//
// Matmul op patterns.
//===----------------------------------------------------------------------===//
// TODO(hinsu): Lower matmul ops with transpose attributes.
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse, ConstBoolAttrFalse),
(HLO_DotOp $a, $b, (NullArrayAttr))>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Nullary op patterns. // Nullary op patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//