[KERNEL_GEN] Clean up bufferize.cc (mostly includes and deps).
PiperOrigin-RevId: 340653792 Change-Id: I064652fa6339d5fe5388cacf147ca4d92af9d5d7
This commit is contained in:
		
							parent
							
								
									7c2b39e4bf
								
							
						
					
					
						commit
						05f93ec4c3
					
				| @ -35,11 +35,9 @@ cc_library( | ||||
|     srcs = ["bufferize.cc"], | ||||
|     hdrs = ["rewriters.h"], | ||||
|     deps = [ | ||||
|         "@llvm-project//llvm:Support", | ||||
|         "@llvm-project//mlir:IR", | ||||
|         "@llvm-project//mlir:Pass", | ||||
|         "@llvm-project//mlir:SCFDialect", | ||||
|         "@llvm-project//mlir:StandardOps", | ||||
|         "@llvm-project//mlir:Support", | ||||
|         "@llvm-project//mlir:Transforms", | ||||
|     ], | ||||
|  | ||||
| @ -17,27 +17,14 @@ limitations under the License. | ||||
| 
 | ||||
| #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
 | ||||
| 
 | ||||
| #include <cstddef> | ||||
| #include <memory> | ||||
| 
 | ||||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/ADT/SmallVector.h" | ||||
| #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 | ||||
| #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Attributes.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Function.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Operation.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/OperationSupport.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 | ||||
| #include "mlir/Pass/Pass.h"  // from @llvm-project
 | ||||
| #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace kernel_gen { | ||||
| namespace transforms { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| class ConstantOpConverter : public OpConversionPattern<ConstantOp> { | ||||
| @ -50,22 +37,22 @@ class ConstantOpConverter : public OpConversionPattern<ConstantOp> { | ||||
|     // We only need to bufferize tensor constants.
 | ||||
|     Location loc = op.getLoc(); | ||||
|     auto result_type = op.getType().dyn_cast<RankedTensorType>(); | ||||
|     if (!result_type) return failure(); | ||||
|     if (result_type.getNumDynamicDims() != 0 || result_type.getRank() != 1) | ||||
|     if (!result_type || !result_type.hasStaticShape() || | ||||
|         result_type.getRank() != 1) | ||||
|       return failure(); | ||||
| 
 | ||||
|     auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>(); | ||||
|     auto memref_type = MemRefType::get({result_type.getNumElements()}, | ||||
|                                        result_type.getElementType()); | ||||
| 
 | ||||
|     Value buffer = rewriter.create<AllocaOp>(loc, memref_type); | ||||
| 
 | ||||
|     auto elements_attr = op.getValue().dyn_cast<DenseElementsAttr>(); | ||||
|     bool all_same_elems = elements_attr.isSplat(); | ||||
|     Value value; | ||||
|     if (elements_attr.isSplat()) | ||||
|     if (all_same_elems) | ||||
|       value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue()); | ||||
|     for (auto pair : llvm::enumerate(elements_attr.getAttributeValues())) { | ||||
|       if (!elements_attr.isSplat()) | ||||
|         value = rewriter.create<ConstantOp>(loc, pair.value()); | ||||
|       Value index = rewriter.create<ConstantIndexOp>(loc, pair.index()); | ||||
|     for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) { | ||||
|       if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value()); | ||||
|       Value index = rewriter.create<ConstantIndexOp>(loc, en.index()); | ||||
|       rewriter.create<StoreOp>(loc, value, buffer, index); | ||||
|     } | ||||
|     rewriter.replaceOp(op, {buffer}); | ||||
| @ -82,14 +69,14 @@ class TensorFromElementsOpConverter | ||||
|       TensorFromElementsOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const final { | ||||
|     Location loc = op.getLoc(); | ||||
|     ShapedType result_type = op.getType().cast<ShapedType>(); | ||||
|     auto result_type = op.getType().cast<ShapedType>(); | ||||
|     int number_of_elements = op.elements().size(); | ||||
|     MemRefType memref_type = | ||||
|         MemRefType::get({number_of_elements}, result_type.getElementType()); | ||||
|     Value result = rewriter.create<AllocaOp>(loc, memref_type); | ||||
|     for (auto operand : llvm::enumerate(operands)) { | ||||
|       Value index = rewriter.create<ConstantIndexOp>(loc, operand.index()); | ||||
|       rewriter.create<StoreOp>(loc, operand.value(), result, index); | ||||
|     for (auto en : llvm::enumerate(operands)) { | ||||
|       Value index = rewriter.create<ConstantIndexOp>(loc, en.index()); | ||||
|       rewriter.create<StoreOp>(loc, en.value(), result, index); | ||||
|     } | ||||
|     rewriter.replaceOp(op, {result}); | ||||
|     return success(); | ||||
| @ -107,7 +94,7 @@ class DynamicTensorFromElementsOpConverter | ||||
|     // Allocate memory on stack.
 | ||||
|     Location loc = op.getLoc(); | ||||
|     DynamicTensorFromElementsOp::Adaptor transformed(operands); | ||||
|     RankedTensorType tensor_ty = op.getType().cast<RankedTensorType>(); | ||||
|     auto tensor_ty = op.getType().cast<RankedTensorType>(); | ||||
|     MemRefType memref_type = | ||||
|         MemRefType::get(tensor_ty.getShape(), tensor_ty.getElementType()); | ||||
|     Value result = rewriter.create<AllocaOp>(loc, memref_type, | ||||
| @ -121,7 +108,7 @@ class DynamicTensorFromElementsOpConverter | ||||
|     SmallVector<Value, 4> steps(rank, one); | ||||
|     SmallVector<Value, 4> upper_bounds; | ||||
|     int next_dynamic_index = 0; | ||||
|     for (int i = 0; i < rank; i++) { | ||||
|     for (int i = 0; i < rank; ++i) { | ||||
|       Value ub = tensor_ty.isDynamicDim(i) | ||||
|                      ? transformed.dynamicExtents()[next_dynamic_index++] | ||||
|                      : rewriter.create<ConstantIndexOp>( | ||||
| @ -173,7 +160,6 @@ class ExtractElementOpConversion | ||||
|     if (!adaptor.aggregate().getType().isa<BaseMemRefType>()) { | ||||
|       return failure(); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(), | ||||
|                                         adaptor.indices()); | ||||
|     return success(); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user