[KERNEL_GEN] Clean up bufferize.cc (mostly includes and deps).

PiperOrigin-RevId: 340653792
Change-Id: I064652fa6339d5fe5388cacf147ca4d92af9d5d7
This commit is contained in:
Alexander Belyaev 2020-11-04 08:31:09 -08:00 committed by TensorFlower Gardener
parent 7c2b39e4bf
commit 05f93ec4c3
2 changed files with 15 additions and 31 deletions

View File

@ -35,11 +35,9 @@ cc_library(
srcs = ["bufferize.cc"], srcs = ["bufferize.cc"],
hdrs = ["rewriters.h"], hdrs = ["rewriters.h"],
deps = [ deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
], ],

View File

@ -17,27 +17,14 @@ limitations under the License.
#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/Bufferize.h" // from @llvm-project
#include <cstddef>
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir { namespace mlir {
namespace kernel_gen { namespace kernel_gen {
namespace transforms { namespace transforms {
namespace { namespace {
class ConstantOpConverter : public OpConversionPattern<ConstantOp> { class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
@ -50,22 +37,22 @@ class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
// We only need to bufferize tensor constants. // We only need to bufferize tensor constants.
Location loc = op.getLoc(); Location loc = op.getLoc();
auto result_type = op.getType().dyn_cast<RankedTensorType>(); auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type) return failure(); if (!result_type || !result_type.hasStaticShape() ||
if (result_type.getNumDynamicDims() != 0 || result_type.getRank() != 1) result_type.getRank() != 1)
return failure(); return failure();
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
auto memref_type = MemRefType::get({result_type.getNumElements()}, auto memref_type = MemRefType::get({result_type.getNumElements()},
result_type.getElementType()); result_type.getElementType());
Value buffer = rewriter.create<AllocaOp>(loc, memref_type); Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
bool all_same_elems = elements_attr.isSplat();
Value value; Value value;
if (elements_attr.isSplat()) if (all_same_elems)
value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue()); value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
for (auto pair : llvm::enumerate(elements_attr.getAttributeValues())) { for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
if (!elements_attr.isSplat()) if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
value = rewriter.create<ConstantOp>(loc, pair.value()); Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
Value index = rewriter.create<ConstantIndexOp>(loc, pair.index());
rewriter.create<StoreOp>(loc, value, buffer, index); rewriter.create<StoreOp>(loc, value, buffer, index);
} }
rewriter.replaceOp(op, {buffer}); rewriter.replaceOp(op, {buffer});
@ -82,14 +69,14 @@ class TensorFromElementsOpConverter
TensorFromElementsOp op, ArrayRef<Value> operands, TensorFromElementsOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc(); Location loc = op.getLoc();
ShapedType result_type = op.getType().cast<ShapedType>(); auto result_type = op.getType().cast<ShapedType>();
int number_of_elements = op.elements().size(); int number_of_elements = op.elements().size();
MemRefType memref_type = MemRefType memref_type =
MemRefType::get({number_of_elements}, result_type.getElementType()); MemRefType::get({number_of_elements}, result_type.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type); Value result = rewriter.create<AllocaOp>(loc, memref_type);
for (auto operand : llvm::enumerate(operands)) { for (auto en : llvm::enumerate(operands)) {
Value index = rewriter.create<ConstantIndexOp>(loc, operand.index()); Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
rewriter.create<StoreOp>(loc, operand.value(), result, index); rewriter.create<StoreOp>(loc, en.value(), result, index);
} }
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
return success(); return success();
@ -107,7 +94,7 @@ class DynamicTensorFromElementsOpConverter
// Allocate memory on stack. // Allocate memory on stack.
Location loc = op.getLoc(); Location loc = op.getLoc();
DynamicTensorFromElementsOp::Adaptor transformed(operands); DynamicTensorFromElementsOp::Adaptor transformed(operands);
RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>(); auto tensor_ty = op.getType().cast<RankedTensorType>();
MemRefType memref_type = MemRefType memref_type =
MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType()); MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType());
Value result = rewriter.create<AllocaOp>(loc, memref_type, Value result = rewriter.create<AllocaOp>(loc, memref_type,
@ -121,7 +108,7 @@ class DynamicTensorFromElementsOpConverter
SmallVector<Value, 4> steps(rank, one); SmallVector<Value, 4> steps(rank, one);
SmallVector<Value, 4> upper_bounds; SmallVector<Value, 4> upper_bounds;
int next_dynamic_index = 0; int next_dynamic_index = 0;
for (int i = 0; i < rank; i++) { for (int i = 0; i < rank; ++i) {
Value ub = tensor_ty.isDynamicDim(i) Value ub = tensor_ty.isDynamicDim(i)
? transformed.dynamicExtents()[next_dynamic_index++] ? transformed.dynamicExtents()[next_dynamic_index++]
: rewriter.create<ConstantIndexOp>( : rewriter.create<ConstantIndexOp>(
@ -173,7 +160,6 @@ class ExtractElementOpConversion
if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) { if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) {
return failure(); return failure();
} }
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(), rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
adaptor.indices()); adaptor.indices());
return success(); return success();