Add XLA lowering pattern for TensorFlow MatMul op
This only handles the simplest case of no transpose. MNIST model only uses this case. PiperOrigin-RevId: 266990244
This commit is contained in:
parent
99f8d44812
commit
17099181cd
@ -165,6 +165,18 @@ func @const() -> tensor<2xi32> {
|
|||||||
return %0: tensor<2xi32>
|
return %0: tensor<2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Matmul op legalizations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: matmul_notranspose
|
||||||
|
func @matmul_notranspose(%arg0: tensor<5x7xf32>, %arg1: tensor<7x11xf32>) -> tensor<5x11xf32> {
|
||||||
|
// CHECK: "xla_hlo.dot"(%arg0, %arg1)
|
||||||
|
%0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>
|
||||||
|
|
||||||
|
return %0 : tensor<5x11xf32>
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Relu op legalizations.
|
// Relu op legalizations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -20,6 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td"
|
|||||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||||
|
|
||||||
|
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||||
def NullElementsAttr : NativeCodeCall<"ElementsAttr()">;
|
def NullElementsAttr : NativeCodeCall<"ElementsAttr()">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -92,6 +93,14 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
|
|||||||
|
|
||||||
def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>;
|
def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Matmul op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO(hinsu): Lower matmul ops with transpose attributes.
|
||||||
|
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse, ConstBoolAttrFalse),
|
||||||
|
(HLO_DotOp $a, $b, (NullArrayAttr))>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Nullary op patterns.
|
// Nullary op patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
Loading…
Reference in New Issue
Block a user