Add XLA lowering pattern for TensorFlow MatMul op
This only handles the simplest case of no transpose. MNIST model only uses this case. PiperOrigin-RevId: 266990244
This commit is contained in:
parent
99f8d44812
commit
17099181cd
@ -165,6 +165,18 @@ func @const() -> tensor<2xi32> {
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Matmul op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: matmul_notranspose
|
||||
func @matmul_notranspose(%arg0: tensor<5x7xf32>, %arg1: tensor<7x11xf32>) -> tensor<5x11xf32> {
|
||||
// CHECK: "xla_hlo.dot"(%arg0, %arg1)
|
||||
%0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>
|
||||
|
||||
return %0 : tensor<5x11xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Relu op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,6 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||
def NullElementsAttr : NativeCodeCall<"ElementsAttr()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -92,6 +93,14 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
|
||||
|
||||
def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Matmul op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(hinsu): Lower matmul ops with transpose attributes.
|
||||
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse, ConstBoolAttrFalse),
|
||||
(HLO_DotOp $a, $b, (NullArrayAttr))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Nullary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user