[KERNEL_GEN] Clean up bufferize.cc (mostly includes and deps).

PiperOrigin-RevId: 340653792
Change-Id: I064652fa6339d5fe5388cacf147ca4d92af9d5d7
This commit is contained in:
Alexander Belyaev 2020-11-04 08:31:09 -08:00 committed by TensorFlower Gardener
parent 7c2b39e4bf
commit 05f93ec4c3
2 changed files with 15 additions and 31 deletions

View File

@@ -35,11 +35,9 @@ cc_library(
srcs = ["bufferize.cc"],
hdrs = ["rewriters.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],

View File

@@ -17,27 +17,14 @@ limitations under the License.
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
#include <cstddef>
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
@@ -50,22 +37,22 @@ class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
// We only need to bufferize tensor constants.
Location loc = op.getLoc();
auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type) return failure();
if (result_type.getNumDynamicDims() != 0 || result_type.getRank() != 1)
if (!result_type || !result_type.hasStaticShape() ||
result_type.getRank() != 1)
return failure();
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
auto memref_type = MemRefType::get({result_type.getNumElements()},
result_type.getElementType());
Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
bool all_same_elems = elements_attr.isSplat();
Value value;
if (elements_attr.isSplat())
if (all_same_elems)
value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
for (auto pair : llvm::enumerate(elements_attr.getAttributeValues())) {
if (!elements_attr.isSplat())
value = rewriter.create<ConstantOp>(loc, pair.value());
Value index = rewriter.create<ConstantIndexOp>(loc, pair.index());
for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
rewriter.create<StoreOp>(loc, value, buffer, index);
}
rewriter.replaceOp(op, {buffer});
@@ -82,14 +69,14 @@ class TensorFromElementsOpConverter
TensorFromElementsOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
ShapedType result_type = op.getType().cast<ShapedType>();
auto result_type = op.getType().cast<ShapedType>();
int number_of_elements = op.elements().size();
MemRefType memref_type =
MemRefType::get({number_of_elements}, result_type.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type);
for (auto operand : llvm::enumerate(operands)) {
Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
rewriter.create<StoreOp>(loc, operand.value(), result, index);
for (auto en : llvm::enumerate(operands)) {
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
rewriter.create<StoreOp>(loc, en.value(), result, index);
}
rewriter.replaceOp(op, {result});
return success();
@@ -107,7 +94,7 @@ class DynamicTensorFromElementsOpConverter
// Allocate memory on stack.
Location loc = op.getLoc();
DynamicTensorFromElementsOp::Adaptor transformed(operands);
RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>();
auto tensor_ty = op.getType().cast<RankedTensorType>();
MemRefType memref_type =
MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type,
@@ -121,7 +108,7 @@ class DynamicTensorFromElementsOpConverter
SmallVector<Value, 4> steps(rank, one);
SmallVector<Value, 4> upper_bounds;
int next_dynamic_index = 0;
for (int i = 0; i < rank; i++) {
for (int i = 0; i < rank; ++i) {
Value ub = tensor_ty.isDynamicDim(i)
? transformed.dynamicExtents()[next_dynamic_index++]
: rewriter.create<ConstantIndexOp>(
@@ -173,7 +160,6 @@ class ExtractElementOpConversion
if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) {
return failure();
}
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
adaptor.indices());
return success();