diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index c0de6f557ab..d67739a739b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -344,12 +344,56 @@ class LowerPackOp : public OpRewritePattern {
   }
 };
 
+// Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints,
+// since we currently don't have an implementation that can use this
+// information. Adds appropriate casts where necessary to align element types
+// of operands and result for `TF::MatMulOp`.
+class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
+                                PatternRewriter &rewriter) const override {
+    // Result type must be f32 for applying the pattern (currently this is
+    // required by the op anyway but this might change).
+    if (!op.product().getType().cast<ShapedType>().getElementType().isF32()) {
+      return failure();
+    }
+    MLIRContext *context = rewriter.getContext();
+    llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
+    for (Value &operand : operands) {
+      TensorType tensor_type = operand.getType().cast<TensorType>();
+      Type element_type = tensor_type.getElementType();
+      if (element_type.isF32()) continue;
+      // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
+      // must be bf16 here.
+      assert(element_type.isBF16());
+      Type tensor_type_f32;
+      if (tensor_type.hasRank()) {
+        tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
+                                                FloatType::getF32(context));
+      } else {
+        tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
+      }
+      // Add cast to f32 to conform with element type of result.
+      operand =
+          rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
+    }
+    Value result = rewriter.create<TF::MatMulOp>(
+        op.getLoc(), op.product().getType(), operands[0], operands[1],
+        op.transpose_a(), op.transpose_b());
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
 }  // namespace
 
 void PopulateLoweringTFPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
   patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
-                   LowerPackOp>(context);
+                   LowerPackOp, LowerSparseMatMulOp>(context);
   populateWithGenerated(context, patterns);
 }
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index b724c1b08e0..b1e74e354fe 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -5400,50 +5400,6 @@ class ConvertQrOp : public OpRewritePattern {
   }
 };
 
-// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
-// hints, since we currently don't have an implementation that can use this
-// information. Adds appropriate casts where necessary to align element types
-// of operands and result for `TF::MatMulOp`.
-class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
-                                PatternRewriter &rewriter) const override {
-    // Result type must be f32 for applying the pattern (currently this is
-    // required by the op anyway but this might change).
-    if (!op.product().getType().cast<ShapedType>().getElementType().isF32()) {
-      return failure();
-    }
-    MLIRContext *context = rewriter.getContext();
-    llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
-    for (Value &operand : operands) {
-      TensorType tensor_type = operand.getType().cast<TensorType>();
-      Type element_type = tensor_type.getElementType();
-      if (element_type.isF32()) continue;
-      // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
-      // must be bf16 here.
-      assert(element_type.isBF16());
-      Type tensor_type_f32;
-      if (tensor_type.hasRank()) {
-        tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
-                                                FloatType::getF32(context));
-      } else {
-        tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
-      }
-      // Add cast to f32 to conform with element type of result.
-      operand =
-          rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
-    }
-    Value result = rewriter.create<TF::MatMulOp>(
-        op.getLoc(), op.product().getType(), operands[0], operands[1],
-        op.transpose_a(), op.transpose_b());
-
-    rewriter.replaceOp(op, {result});
-    return success();
-  }
-};
-
 // Emits debug information which includes the number of ops of each type which
 // failed to legalize.
 void EmitLegalizationErrors(Operation *op,
@@ -5533,11 +5489,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
       ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
       ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
       ConvertSoftmaxOp<TF::LogSoftmaxOp, false>,
-      ConvertSoftmaxOp<TF::SoftmaxOp, true>, ConvertSparseMatMulOp,
-      ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
-      ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
-      ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
-      ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
+      ConvertSoftmaxOp<TF::SoftmaxOp, true>, ConvertSplitOp, ConvertSplitVOp,
+      ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
+      ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
+      ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
       ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
       ConvertRandomShuffleOp, ConvertXlaShardingOp,
       ConvertXlaDynamicUpdateSliceOp>(op->getContext());
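Reviewer note, not part of the patch: the sketch below shows one way the lowering patterns registered by PopulateLoweringTFPatterns, which after this change include LowerSparseMatMulOp, are typically applied inside a function pass. The pass name LowerTFSketchPass is hypothetical, and the header path and API spellings (PassWrapper, FunctionPass, OwningRewritePatternList, applyPatternsAndFoldGreedily) assume the MLIR revision this diff is written against.

// Hypothetical usage sketch -- not part of this patch.
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"

namespace {

// Applies the TF-dialect lowering patterns (LowerPackOp, LowerSparseMatMulOp,
// ...) to a function. With LowerSparseMatMulOp registered, a tf.SparseMatMul
// is rewritten into tf.Cast (bf16 -> f32, only where needed) followed by
// tf.MatMul.
struct LowerTFSketchPass
    : public mlir::PassWrapper<LowerTFSketchPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    // Collects the hand-written patterns plus the TableGen-generated ones.
    mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
    // Greedily rewrite until no more patterns apply.
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace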