Fix legalization of SparseMatMul to MatMul
1) Previously, SparseMatMul legalization didn't work correctly for operands with different element types which is fixed now. 2) Forbidden element types are now correctly checked. 3) Added more unit tests for SparseMatMul legalization. PiperOrigin-RevId: 320696400 Change-Id: I7376b2cf06027ab6153da0831e2f7d39498ee210
This commit is contained in:
parent
38f980d3d7
commit
d0ae0f8c70
@ -524,8 +524,8 @@ are sparse matrices.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$a,
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$b,
|
||||
TensorOf<[BF16, F32]>:$a,
|
||||
TensorOf<[BF16, F32]>:$b,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
|
||||
@ -535,7 +535,7 @@ are sparse matrices.
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$product
|
||||
TensorOf<[F32]>:$product
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
|
||||
|
@ -956,6 +956,41 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
// SparseMatMul where one operand needs to be transposed and the other one not.
|
||||
//
|
||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
|
||||
// CHECK-SAME: permutation = dense<[1, 0]>
|
||||
// CHECK-SAME: -> tensor<4x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
// CHECK: }
|
||||
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
|
||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
// SparseMatMul where one operand needs to be casted and the other one not.
|
||||
//
|
||||
// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
|
||||
// CHECK-SAME: -> tensor<4x5xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
|
||||
// CHECK-SAME: -> tensor<3x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
// CHECK: }
|
||||
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
|
||||
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
|
||||
return %0: tensor<3x5xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPart op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -5358,6 +5358,50 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
|
||||
// hints, since we currently don't have an implementation that can use this
|
||||
// information. Adds appropriate casts where necessary to align element types
|
||||
// of operands and result for `TF::MatMulOp`.
|
||||
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Result type must be f32 for applying the pattern (currently this is
|
||||
// required by the op anyway but this might change).
|
||||
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
|
||||
return failure();
|
||||
}
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
|
||||
for (Value &operand : operands) {
|
||||
TensorType tensor_type = operand.getType().cast<TensorType>();
|
||||
Type element_type = tensor_type.getElementType();
|
||||
if (element_type.isF32()) continue;
|
||||
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
|
||||
// must be bf16 here.
|
||||
assert(element_type.isBF16());
|
||||
Type tensor_type_f32;
|
||||
if (tensor_type.hasRank()) {
|
||||
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
|
||||
FloatType::getF32(context));
|
||||
} else {
|
||||
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
|
||||
}
|
||||
// Add cast to f32 to conform with element type of result.
|
||||
operand =
|
||||
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
|
||||
}
|
||||
Value result = rewriter.create<TF::MatMulOp>(
|
||||
op.getLoc(), op.product().getType(), operands[0], operands[1],
|
||||
op.transpose_a(), op.transpose_b());
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Emits debug information which includes the number of ops of each type which
|
||||
// failed to legalize.
|
||||
void EmitLegalizationErrors(Operation *op,
|
||||
@ -5446,10 +5490,11 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
|
||||
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
||||
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
|
||||
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
|
||||
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
|
||||
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
|
||||
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
|
||||
ConvertRandomShuffleOp, ConvertXlaShardingOp,
|
||||
ConvertXlaDynamicUpdateSliceOp>(op->getContext());
|
||||
|
@ -339,18 +339,6 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
|
||||
(TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
|
||||
/*precision_config=*/(NullArrayAttr))>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SparseMatMul op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Ignores the sparseness hints and translates tf.SparseMatMul to tf.MatMul
|
||||
// until we will have an implementation that can use the information.
|
||||
def SparseMatMulToMatMul : Pat<(TF_SparseMatMulOp $a, $b, $a_sparse, $b_sparse,
|
||||
$transpose_a, $transpose_b),
|
||||
(TF_MatMulOp $a, $b, $transpose_a,
|
||||
$transpose_b)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPart op pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user