[KERNEL_GEN] Clean up bufferize.cc (mostly includes and deps).
PiperOrigin-RevId: 340653792 Change-Id: I064652fa6339d5fe5388cacf147ca4d92af9d5d7
This commit is contained in:
parent
7c2b39e4bf
commit
05f93ec4c3
@ -35,11 +35,9 @@ cc_library(
|
||||
srcs = ["bufferize.cc"],
|
||||
hdrs = ["rewriters.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
|
||||
@ -17,27 +17,14 @@ limitations under the License.
|
||||
|
||||
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace kernel_gen {
|
||||
namespace transforms {
|
||||
|
||||
namespace {
|
||||
|
||||
class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
|
||||
@ -50,22 +37,22 @@ class ConstantOpConverter : public OpConversionPattern<ConstantOp> {
|
||||
// We only need to bufferize tensor constants.
|
||||
Location loc = op.getLoc();
|
||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) return failure();
|
||||
if (result_type.getNumDynamicDims() != 0 || result_type.getRank() != 1)
|
||||
if (!result_type || !result_type.hasStaticShape() ||
|
||||
result_type.getRank() != 1)
|
||||
return failure();
|
||||
|
||||
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
|
||||
auto memref_type = MemRefType::get({result_type.getNumElements()},
|
||||
result_type.getElementType());
|
||||
|
||||
Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
|
||||
|
||||
auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>();
|
||||
bool all_same_elems = elements_attr.isSplat();
|
||||
Value value;
|
||||
if (elements_attr.isSplat())
|
||||
if (all_same_elems)
|
||||
value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
|
||||
for (auto pair : llvm::enumerate(elements_attr.getAttributeValues())) {
|
||||
if (!elements_attr.isSplat())
|
||||
value = rewriter.create<ConstantOp>(loc, pair.value());
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, pair.index());
|
||||
for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
|
||||
if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
|
||||
rewriter.create<StoreOp>(loc, value, buffer, index);
|
||||
}
|
||||
rewriter.replaceOp(op, {buffer});
|
||||
@ -82,14 +69,14 @@ class TensorFromElementsOpConverter
|
||||
TensorFromElementsOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
ShapedType result_type = op.getType().cast<ShapedType>();
|
||||
auto result_type = op.getType().cast<ShapedType>();
|
||||
int number_of_elements = op.elements().size();
|
||||
MemRefType memref_type =
|
||||
MemRefType::get({number_of_elements}, result_type.getElementType());
|
||||
Value result = rewriter.create<AllocaOp>(loc, memref_type);
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
|
||||
rewriter.create<StoreOp>(loc, operand.value(), result, index);
|
||||
for (auto en : llvm::enumerate(operands)) {
|
||||
Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
|
||||
rewriter.create<StoreOp>(loc, en.value(), result, index);
|
||||
}
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
@ -107,7 +94,7 @@ class DynamicTensorFromElementsOpConverter
|
||||
// Allocate memory on stack.
|
||||
Location loc = op.getLoc();
|
||||
DynamicTensorFromElementsOp::Adaptor transformed(operands);
|
||||
RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>();
|
||||
auto tensor_ty = op.getType().cast<RankedTensorType>();
|
||||
MemRefType memref_type =
|
||||
MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType());
|
||||
Value result = rewriter.create<AllocaOp>(loc, memref_type,
|
||||
@ -121,7 +108,7 @@ class DynamicTensorFromElementsOpConverter
|
||||
SmallVector<Value, 4> steps(rank, one);
|
||||
SmallVector<Value, 4> upper_bounds;
|
||||
int next_dynamic_index = 0;
|
||||
for (int i = 0; i < rank; i++) {
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
Value ub = tensor_ty.isDynamicDim(i)
|
||||
? transformed.dynamicExtents()[next_dynamic_index++]
|
||||
: rewriter.create<ConstantIndexOp>(
|
||||
@ -173,7 +160,6 @@ class ExtractElementOpConversion
|
||||
if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
|
||||
adaptor.indices());
|
||||
return success();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user