[MLIR][KernelGen] Lower tf.Acos to LMHLO.

- Add ranked code generation for `mhlo.compare/select`
- Add bufferization for `tensor_cast`
- Add lowerings for `Atan2Op`

PiperOrigin-RevId: 332407734
Change-Id: Ib3e0b8e57cd8660ed8e8aa85fc9bd7047f31d152
This commit is contained in:
A. Unique TensorFlower 2020-09-18 01:39:48 -07:00 committed by TensorFlower Gardener
parent 6c102af171
commit 5a701ae4bb
5 changed files with 35 additions and 6 deletions

View File

@ -40,6 +40,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
MAP_HLO_TO_LHLO(AbsOp);
MAP_HLO_TO_LHLO(AddOp);
MAP_HLO_TO_LHLO(AndOp);
MAP_HLO_TO_LHLO(Atan2Op);
MAP_HLO_TO_LHLO(BroadcastInDimOp);
MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);

View File

@ -488,6 +488,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,
HloToLhloOpConverter<mhlo::Atan2Op>,
HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
HloToLhloOpConverter<mhlo::CeilOp>,
HloToLhloOpConverter<mhlo::CompareOp>,

View File

@ -145,6 +145,8 @@ struct TransformUnrankedHloPass
MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
#undef ADD_LEGAL_MHLO
#undef ADD_LEGAL_CHLO
AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
// Populate rewrite patterns.
OwningRewritePatternList patterns;
@ -168,7 +170,9 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
patterns->insert<
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA)>(context);
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA),
ElementwiseOpConversion<mhlo::CompareOp>,
ElementwiseOpConversion<mhlo::SelectOp>>(context);
// clang-format on
#undef MAP_UNARY
#undef MAP_BINARY

View File

@ -150,15 +150,37 @@ class ExtractElementOpConversion
}
};
class TensorCastOpConverter
: public BufferAssignmentOpConversionPattern<TensorCastOp> {
public:
using BufferAssignmentOpConversionPattern<
TensorCastOp>::BufferAssignmentOpConversionPattern;
LogicalResult matchAndRewrite(
TensorCastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensor_ty = op.getType().dyn_cast<RankedTensorType>();
if (!tensor_ty) return failure();
Value arg = operands.front();
auto arg_ty = arg.getType().dyn_cast<MemRefType>();
if (!arg_ty) return failure();
auto result_ty = converter->convertType(tensor_ty);
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, arg, result_ty);
return success();
}
};
} // namespace
void populateStandardBufferizePattern(MLIRContext *context,
BufferAssignmentTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns
->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
DynamicTensorFromElementsOpConverter, TensorLoadOpConversion>(
context, converter);
patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
DynamicTensorFromElementsOpConverter, TensorLoadOpConversion,
TensorCastOpConverter>(context, converter);
}
} // namespace transforms

View File

@ -79,7 +79,8 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<mhlo::MhloDialect>();
target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
TensorFromElementsOp, TensorLoadOp, YieldOp>();
TensorFromElementsOp, TensorLoadOp, YieldOp,
TensorCastOp>();
target.addDynamicallyLegalOp<TensorStoreOp>([&](TensorStoreOp op) {
return !op.tensor().getType().isa<UnrankedTensorType>();
});