Moved tf.SparseMatMul lowering code

The code should be in lower_tf.cc since it lowers from TF to TF and not to HLO.

PiperOrigin-RevId: 323017854
Change-Id: Id93e99b65bc38266f5612fb1cacff1e250699c81
parent 06eb028030
commit 40d01ae6fd
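
For reference, the rewrite the moved pattern performs, sketched in TF-dialect textual IR. The snippet is illustrative only; the value names and tensor shapes are made up, not taken from the commit:

// Before: the sparseness hints are attributes on the op, and one operand
// may be bf16.
%p = "tf.SparseMatMul"(%a, %b) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>

// After: the bf16 operand is cast to f32 and the hints are dropped.
%b_f32 = "tf.Cast"(%b) : (tensor<4x5xbf16>) -> tensor<4x5xf32>
%p2 = "tf.MatMul"(%a, %b_f32) {transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>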
@@ -344,12 +344,56 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
   }
 };
 
+// Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints,
+// since we currently don't have an implementation that can use this
+// information. Adds appropriate casts where necessary to align element types
+// of operands and result for `TF::MatMulOp`.
+class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
+ public:
+  using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
+                                PatternRewriter &rewriter) const override {
+    // Result type must be f32 for applying the pattern (currently this is
+    // required by the op anyway but this might change).
+    if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
+      return failure();
+    }
+    MLIRContext *context = rewriter.getContext();
+    llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
+    for (Value &operand : operands) {
+      TensorType tensor_type = operand.getType().cast<TensorType>();
+      Type element_type = tensor_type.getElementType();
+      if (element_type.isF32()) continue;
+      // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
+      // must be bf16 here.
+      assert(element_type.isBF16());
+      Type tensor_type_f32;
+      if (tensor_type.hasRank()) {
+        tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
+                                                FloatType::getF32(context));
+      } else {
+        tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
+      }
+      // Add cast to f32 to conform with element type of result.
+      operand =
+          rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
+    }
+    Value result = rewriter.create<TF::MatMulOp>(
+        op.getLoc(), op.product().getType(), operands[0], operands[1],
+        op.transpose_a(), op.transpose_b());
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
 }  // namespace
 
 void PopulateLoweringTFPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
   patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
-                   LowerPackOp>(context);
+                   LowerPackOp, LowerSparseMatMulOp>(context);
   populateWithGenerated(context, patterns);
 }
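For context, patterns registered through PopulateLoweringTFPatterns are applied by MLIR's greedy pattern rewrite driver. A minimal sketch of such a driver follows; the wrapper function RunTFLoweringPatterns is hypothetical and not part of this commit:

// Hypothetical helper: gathers the TF->TF lowering patterns (now including
// LowerSparseMatMulOp) and greedily applies them to a function body.
// Assumes the usual includes, e.g.
// "mlir/Transforms/GreedyPatternRewriteDriver.h" and
// "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h".
void RunTFLoweringPatterns(mlir::FuncOp func) {
  mlir::MLIRContext *context = func.getContext();
  mlir::OwningRewritePatternList patterns;
  mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
  // Rewrites matching ops in place until no pattern applies anymore.
  mlir::applyPatternsAndFoldGreedily(func, patterns);
}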
@@ -5400,50 +5400,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
   }
 };
 
-// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
-// hints, since we currently don't have an implementation that can use this
-// information. Adds appropriate casts where necessary to align element types
-// of operands and result for `TF::MatMulOp`.
-class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
- public:
-  using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
-                                PatternRewriter &rewriter) const override {
-    // Result type must be f32 for applying the pattern (currently this is
-    // required by the op anyway but this might change).
-    if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
-      return failure();
-    }
-    MLIRContext *context = rewriter.getContext();
-    llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
-    for (Value &operand : operands) {
-      TensorType tensor_type = operand.getType().cast<TensorType>();
-      Type element_type = tensor_type.getElementType();
-      if (element_type.isF32()) continue;
-      // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
-      // must be bf16 here.
-      assert(element_type.isBF16());
-      Type tensor_type_f32;
-      if (tensor_type.hasRank()) {
-        tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
-                                                FloatType::getF32(context));
-      } else {
-        tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
-      }
-      // Add cast to f32 to conform with element type of result.
-      operand =
-          rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
-    }
-    Value result = rewriter.create<TF::MatMulOp>(
-        op.getLoc(), op.product().getType(), operands[0], operands[1],
-        op.transpose_a(), op.transpose_b());
-
-    rewriter.replaceOp(op, {result});
-    return success();
-  }
-};
-
 // Emits debug information which includes the number of ops of each type which
 // failed to legalize.
 void EmitLegalizationErrors(Operation *op,
@@ -5533,11 +5489,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
       ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
       ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
       ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
-      ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
-      ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
-      ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
-      ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
-      ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
+      ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
+      ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
+      ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
+      ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
       ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
       ConvertRandomShuffleOp, ConvertXlaShardingOp,
       ConvertXlaDynamicUpdateSliceOp>(op->getContext());
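A lowering like this would normally be pinned down with a FileCheck test. A minimal sketch, assuming the lowering patterns are exercised through a tf-opt test pass (the -test-tf-lower-tf flag and the test function below are assumptions, not part of this commit):

// RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s

// CHECK-LABEL: func @sparse_matmul
func @sparse_matmul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
  // CHECK: %[[CAST:.*]] = "tf.Cast"(%arg1)
  // CHECK: "tf.MatMul"(%arg0, %[[CAST]])
  // CHECK-NOT: tf.SparseMatMul
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}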