Moved tf.SparseMatMul lowering code

The code should be in lower_tf.cc since it lowers from TF to TF, not from TF to HLO.

PiperOrigin-RevId: 323017854
Change-Id: Id93e99b65bc38266f5612fb1cacff1e250699c81
This commit is contained in:
Michael Gester 2020-07-24 10:08:36 -07:00 committed by TensorFlower Gardener
parent 06eb028030
commit 40d01ae6fd
2 changed files with 49 additions and 50 deletions

View File

@@ -344,12 +344,56 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
}
};
// Rewrites `TF::SparseMatMulOp` into a plain `TF::MatMulOp`, dropping the
// sparseness hints since no current implementation makes use of them. Any
// bf16 operand is first cast to f32 so that both matmul operands agree with
// the (f32) result element type.
class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
 public:
  using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies when the result element type is f32. The op
    // currently requires this anyway, but that constraint might be relaxed.
    auto product_type = op.product().getType().cast<TensorType>();
    if (!product_type.getElementType().isF32()) return failure();

    MLIRContext *context = rewriter.getContext();
    Type f32 = FloatType::getF32(context);

    llvm::SmallVector<Value, 2> matmul_inputs{op.a(), op.b()};
    for (Value &input : matmul_inputs) {
      auto input_type = input.getType().cast<TensorType>();
      Type input_element_type = input_type.getElementType();
      if (input_element_type.isF32()) continue;
      // `SparseMatMulOp` admits only f32 and bf16 element types, so a
      // non-f32 operand must be bf16 here.
      assert(input_element_type.isBF16());
      // Build the f32 version of the operand type, preserving rankedness.
      Type casted_type =
          input_type.hasRank()
              ? static_cast<Type>(
                    RankedTensorType::get(input_type.getShape(), f32))
              : static_cast<Type>(UnrankedTensorType::get(f32));
      // Cast the bf16 operand up to f32 to match the result element type.
      input = rewriter.create<TF::CastOp>(op.getLoc(), casted_type, input);
    }

    Value matmul = rewriter.create<TF::MatMulOp>(
        op.getLoc(), op.product().getType(), matmul_inputs[0],
        matmul_inputs[1], op.transpose_a(), op.transpose_b());
    rewriter.replaceOp(op, {matmul});
    return success();
  }
};
} // namespace
// Registers all hand-written TF-to-TF lowering patterns defined in this file,
// plus the TableGen-generated patterns, into `patterns`.
//
// NOTE(review): the extracted diff kept both the pre- and post-change
// `patterns->insert<...>` lines, which is not valid C++; only the
// post-change line (including `LowerSparseMatMulOp`) is retained here.
void PopulateLoweringTFPatterns(MLIRContext *context,
                                OwningRewritePatternList *patterns) {
  patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
                   LowerPackOp, LowerSparseMatMulOp>(context);
  populateWithGenerated(context, patterns);
}

View File

@@ -5400,50 +5400,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`. The sparseness hints are
// ignored because there is currently no implementation that can exploit
// them; bf16 operands are promoted to f32 via `TF::CastOp` so the operand
// element types line up with the f32 result of the `TF::MatMulOp`.
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
 public:
  using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the result element type is f32 (currently required by
    // the op itself, though that may change in the future).
    TensorType product_type = op.product().getType().cast<TensorType>();
    if (!product_type.getElementType().isF32()) return failure();

    Location loc = op.getLoc();
    Type f32_type = FloatType::getF32(rewriter.getContext());

    // Returns `value` unchanged when it is already f32; otherwise inserts a
    // cast from bf16 (the only other element type the op allows) to f32.
    auto promote_to_f32 = [&](Value value) -> Value {
      TensorType type = value.getType().cast<TensorType>();
      if (type.getElementType().isF32()) return value;
      assert(type.getElementType().isBF16());
      Type promoted_type;
      if (type.hasRank()) {
        promoted_type = RankedTensorType::get(type.getShape(), f32_type);
      } else {
        promoted_type = UnrankedTensorType::get(f32_type);
      }
      return rewriter.create<TF::CastOp>(loc, promoted_type, value);
    };

    Value lhs = promote_to_f32(op.a());
    Value rhs = promote_to_f32(op.b());
    Value matmul = rewriter.create<TF::MatMulOp>(
        loc, op.product().getType(), lhs, rhs, op.transpose_a(),
        op.transpose_b());
    rewriter.replaceOp(op, {matmul});
    return success();
  }
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
@@ -5533,11 +5489,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext());