[MLIR] Add e2e test for unranked unary TF op, lowered and run with CPU runner.

PiperOrigin-RevId: 325665428
Change-Id: I3e8a1a3a9551ba470e858fb775a31ca894f47359
This commit is contained in:
Alexander Belyaev 2020-08-09 02:36:32 -07:00 committed by TensorFlower Gardener
parent 4308605c89
commit 947b6c3a4b
2 changed files with 29 additions and 3 deletions

View File

@ -170,7 +170,7 @@ struct TransformUnrankedHloPass
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
// Apply transformation.
if (failed(applyFullConversion(getFunction(), target, patterns)))
if (failed(applyPartialConversion(getFunction(), target, patterns)))
return signalPassFailure();
}
};

View File

@ -44,6 +44,28 @@ namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
// TODO(herhut) : This could become a real pattern in bufferize pass. What we
// would need to do is insert a copy to model the semantics correctly. The same
// is true for the TensorLoad pattern that is already in there. Then buffer
// assignment free insertion and copy removal should clean this up for us.
//
// This patten erases `tensor_store(src_unranked_tensor, dst_unranked_memref)`
// op and replaces the result of the defining op produced `dst_unranked_memref`
// with the rewritten `src_unranked_tensor`.
class UnrankedTensorStoreTestOnlyPattern
: public OpConversionPattern<mlir::TensorStoreOp> {
public:
using OpConversionPattern<mlir::TensorStoreOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op.memref().getDefiningOp(), op.tensor());
rewriter.replaceOp(op, {});
return success();
}
};
struct BufferizePass : public BufferizePassBase<BufferizePass> {
public:
void runOnOperation() override {
@ -57,8 +79,11 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
target.addLegalOp<ModuleTerminatorOp>();
target.addIllegalDialect<mhlo::MhloDialect>();
target.addIllegalOp<TensorFromElementsOp>();
target.addIllegalOp<TensorLoadOp>();
target.addIllegalOp<ExtractElementOp>();
target.addIllegalOp<TensorLoadOp>();
target.addDynamicallyLegalOp<TensorStoreOp>([&](TensorStoreOp op) {
return !op.tensor().getType().isa<UnrankedTensorType>();
});
BufferAssignmentTypeConverter converter;
auto typesAreLegal = [&converter](Operation* op) {
@ -86,8 +111,9 @@ struct BufferizePass : public BufferizePassBase<BufferizePass> {
&converter, &patterns);
populateStandardBufferizePattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
patterns.insert<UnrankedTensorStoreTestOnlyPattern>(func.getContext());
return applyFullConversion(func, target, patterns);
return applyPartialConversion(func, target, patterns);
});
if (result.wasInterrupted()) {
signalPassFailure();