[MLIR] Add e2e test for unranked unary TF op, lowered and run with CPU runner.
PiperOrigin-RevId: 325665428 Change-Id: I3e8a1a3a9551ba470e858fb775a31ca894f47359
This commit is contained in:
parent
4308605c89
commit
947b6c3a4b
@ -170,7 +170,7 @@ struct TransformUnrankedHloPass
|
||||
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
|
||||
|
||||
// Apply transformation.
|
||||
if (failed(applyFullConversion(getFunction(), target, patterns)))
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -44,6 +44,28 @@ namespace {
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
|
||||
|
||||
// TODO(herhut) : This could become a real pattern in bufferize pass. What we
|
||||
// would need to do is insert a copy to model the semantics correctly. The same
|
||||
// is true for the TensorLoad pattern that is already in there. Then buffer
|
||||
// assignment free insertion and copy removal should clean this up for us.
|
||||
//
|
||||
// This patten erases `tensor_store(src_unranked_tensor, dst_unranked_memref)`
|
||||
// op and replaces the result of the defining op produced `dst_unranked_memref`
|
||||
// with the rewritten `src_unranked_tensor`.
|
||||
class UnrankedTensorStoreTestOnlyPattern
|
||||
: public OpConversionPattern<mlir::TensorStoreOp> {
|
||||
public:
|
||||
using OpConversionPattern<mlir::TensorStoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
rewriter.replaceOp(op.memref().getDefiningOp(), op.tensor());
|
||||
rewriter.replaceOp(op, {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
@ -57,8 +79,11 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
target.addIllegalOp<TensorFromElementsOp>();
|
||||
target.addIllegalOp<TensorLoadOp>();
|
||||
target.addIllegalOp<ExtractElementOp>();
|
||||
target.addIllegalOp<TensorLoadOp>();
|
||||
target.addDynamicallyLegalOp<TensorStoreOp>([&](TensorStoreOp op) {
|
||||
return !op.tensor().getType().isa<UnrankedTensorType>();
|
||||
});
|
||||
|
||||
BufferAssignmentTypeConverter converter;
|
||||
auto typesAreLegal = [&converter](Operation* op) {
|
||||
@ -86,8 +111,9 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
|
||||
&converter, &patterns);
|
||||
populateStandardBufferizePattern(func.getContext(), &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
patterns.insert<UnrankedTensorStoreTestOnlyPattern>(func.getContext());
|
||||
|
||||
return applyFullConversion(func, target, patterns);
|
||||
return applyPartialConversion(func, target, patterns);
|
||||
});
|
||||
if (result.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
|
Loading…
Reference in New Issue
Block a user