Fix legalization of SparseMatMul to MatMul

1) Fixed SparseMatMul legalization, which previously did not handle operands
with different element types correctly.
2) Forbidden element types are now correctly checked.
3) Added more unit tests for SparseMatMul legalization.

PiperOrigin-RevId: 320696400
Change-Id: I7376b2cf06027ab6153da0831e2f7d39498ee210
This commit is contained in:
Michael Gester 2020-07-10 16:19:55 -07:00 committed by TensorFlower Gardener
parent 38f980d3d7
commit d0ae0f8c70
4 changed files with 87 additions and 19 deletions

View File

@@ -524,8 +524,8 @@ are sparse matrices.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$a,
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$b,
TensorOf<[BF16, F32]>:$a,
TensorOf<[BF16, F32]>:$b,
DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
@@ -535,7 +535,7 @@ are sparse matrices.
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$product
TensorOf<[F32]>:$product
);
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;

View File

@@ -956,6 +956,41 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten
return %0: tensor<3x5xf32>
}
// SparseMatMul where one operand needs to be transposed and the other one not.
//
// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
// CHECK-SAME: permutation = dense<[1, 0]>
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
// CHECK: }
// Only `transpose_b` is set, so the legalization must materialize an explicit
// transpose of %arg1 (5x4 -> 4x5) before feeding it to the dot.
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
}
// SparseMatMul where one operand needs to be casted and the other one not.
//
// CHECK-LABEL: func @test_sparse_mat_mul_with_cast
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
// CHECK: }
// %arg1 is bf16 while the result is f32, so the legalization must insert a
// convert (cast) of %arg1 to f32 before the dot; %arg0 is left untouched.
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
%0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
return %0: tensor<3x5xf32>
}
//===----------------------------------------------------------------------===//
// MatrixBandPart op legalizations.
//===----------------------------------------------------------------------===//

View File

@@ -5358,6 +5358,50 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Lowers `TF::SparseMatMulOp` into a dense `TF::MatMulOp`. The sparseness
// hints (`a_is_sparse` / `b_is_sparse`) are dropped because no lowering that
// can exploit them exists yet. Any bf16 operand is first cast to f32 so that
// both operands agree with the f32 element type of the result, as required
// by `tf.MatMul`.
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
 public:
  using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies to f32 results. Today the op definition
    // enforces this anyway, but guard against a future relaxation.
    auto result_type = op.product().getType().cast<TensorType>();
    if (!result_type.getElementType().isF32()) return failure();

    Location loc = op.getLoc();
    Type f32 = FloatType::getF32(rewriter.getContext());

    // Returns `value` unchanged when it is already f32; otherwise inserts a
    // tf.Cast to an f32 tensor of the same shape (rankedness preserved).
    auto cast_to_f32 = [&](Value value) -> Value {
      auto tensor_type = value.getType().cast<TensorType>();
      if (tensor_type.getElementType().isF32()) return value;
      // `SparseMatMulOp` operands are restricted to f32 or bf16, so any
      // non-f32 operand must be bf16.
      assert(tensor_type.getElementType().isBF16());
      Type casted_type;
      if (tensor_type.hasRank()) {
        casted_type = RankedTensorType::get(tensor_type.getShape(), f32);
      } else {
        casted_type = UnrankedTensorType::get(f32);
      }
      return rewriter.create<TF::CastOp>(loc, casted_type, value);
    };

    Value lhs = cast_to_f32(op.a());
    Value rhs = cast_to_f32(op.b());

    // Replace with a dense MatMul carrying over the transpose flags; the
    // sparsity attributes are intentionally discarded.
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, result_type, lhs, rhs,
                                              op.transpose_a(),
                                              op.transpose_b());
    return success();
  }
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
@@ -5446,10 +5490,11 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext());

View File

@@ -339,18 +339,6 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
(TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
/*precision_config=*/(NullArrayAttr))>;
//===----------------------------------------------------------------------===//
// SparseMatMul op patterns.
//===----------------------------------------------------------------------===//
// Ignores the sparseness hints ($a_sparse, $b_sparse) and lowers
// tf.SparseMatMul to a dense tf.MatMul, until an implementation that can
// exploit the sparsity information exists.
def SparseMatMulToMatMul : Pat<(TF_SparseMatMulOp $a, $b, $a_sparse, $b_sparse,
$transpose_a, $transpose_b),
(TF_MatMulOp $a, $b, $transpose_a,
$transpose_b)>;
//===----------------------------------------------------------------------===//
// MatrixBandPart op pattern.
//===----------------------------------------------------------------------===//