[MLIR][KernelGen] Lower tf.Acos
to LMHLO.
- Add ranked code generation for `mhlo.compare/select` - Add bufferization for `tensor_cast` - Add lowerings for `Atan2Op` PiperOrigin-RevId: 332407734 Change-Id: Ib3e0b8e57cd8660ed8e8aa85fc9bd7047f31d152
This commit is contained in:
parent
6c102af171
commit
5a701ae4bb
@ -40,6 +40,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
||||
MAP_HLO_TO_LHLO(AbsOp);
|
||||
MAP_HLO_TO_LHLO(AddOp);
|
||||
MAP_HLO_TO_LHLO(AndOp);
|
||||
MAP_HLO_TO_LHLO(Atan2Op);
|
||||
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
||||
MAP_HLO_TO_LHLO(CeilOp);
|
||||
MAP_HLO_TO_LHLO(ConstOp);
|
||||
|
@ -488,6 +488,7 @@ void populateHLOToLHLOConversionPattern(
|
||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||
HloToLhloOpConverter<mhlo::AddOp>,
|
||||
HloToLhloOpConverter<mhlo::AndOp>,
|
||||
HloToLhloOpConverter<mhlo::Atan2Op>,
|
||||
HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
|
||||
HloToLhloOpConverter<mhlo::CeilOp>,
|
||||
HloToLhloOpConverter<mhlo::CompareOp>,
|
||||
|
@ -145,6 +145,8 @@ struct TransformUnrankedHloPass
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
|
||||
#undef ADD_LEGAL_MHLO
|
||||
#undef ADD_LEGAL_CHLO
|
||||
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
|
||||
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
@ -168,7 +170,9 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
patterns->insert<
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA),
|
||||
ElementwiseOpConversion<mhlo::CompareOp>,
|
||||
ElementwiseOpConversion<mhlo::SelectOp>>(context);
|
||||
// clang-format on
|
||||
#undef MAP_UNARY
|
||||
#undef MAP_BINARY
|
||||
|
@ -150,15 +150,37 @@ class ExtractElementOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
class TensorCastOpConverter
|
||||
: public BufferAssignmentOpConversionPattern<TensorCastOp> {
|
||||
public:
|
||||
using BufferAssignmentOpConversionPattern<
|
||||
TensorCastOp>::BufferAssignmentOpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
TensorCastOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensor_ty = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensor_ty) return failure();
|
||||
|
||||
Value arg = operands.front();
|
||||
auto arg_ty = arg.getType().dyn_cast<MemRefType>();
|
||||
if (!arg_ty) return failure();
|
||||
|
||||
auto result_ty = converter->convertType(tensor_ty);
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, arg, result_ty);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateStandardBufferizePattern(MLIRContext *context,
|
||||
BufferAssignmentTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns
|
||||
->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
|
||||
DynamicTensorFromElementsOpConverter, TensorLoadOpConversion>(
|
||||
context, converter);
|
||||
patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
|
||||
DynamicTensorFromElementsOpConverter, TensorLoadOpConversion,
|
||||
TensorCastOpConverter>(context, converter);
|
||||
}
|
||||
|
||||
} // namespace transforms
|
||||
|
@ -79,7 +79,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
|
||||
TensorFromElementsOp, TensorLoadOp, YieldOp>();
|
||||
TensorFromElementsOp, TensorLoadOp, YieldOp,
|
||||
TensorCastOp>();
|
||||
target.addDynamicallyLegalOp<TensorStoreOp>([&](TensorStoreOp op) {
|
||||
return !op.tensor().getType().isa<UnrankedTensorType>();
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user