From e8d84bd57a45dce49cebd5afe09875425813695f Mon Sep 17 00:00:00 2001 From: Alexander Belyaev <pifon@google.com> Date: Mon, 5 Oct 2020 03:54:51 -0700 Subject: [PATCH] [HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass. PiperOrigin-RevId: 335385243 Change-Id: I29682df1f6753900554fd018f8d40464a7ba41fc --- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a808608888c..3485afffc9f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast<ShapedType>(); + auto result_type = result.getType().dyn_cast<RankedTensorType>(); if (!result_type) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects ranked results"; @@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - Operation* op = result.getDefiningOp(); - // Extract the required element out of the vector. SmallVector<Value, 4> dynamic_operands; for (auto shape_element : llvm::enumerate(result_type.getShape())) { if (shape_element.value() != ShapedType::kDynamicSize) continue; - Value index = rewriter->create<ConstantOp>( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), - shape_element.index())); - Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand, - ValueRange{index}); + Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index()); + Value alloc_operand = + rewriter->create<ExtractElementOp>(loc, shape_operand, index); if (!alloc_operand.getType().isIndex()) { alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand, rewriter->getIndexType()); @@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, dynamic_operands.push_back(alloc_operand); } - // Insert in front of op to ensure sizes are available. - OpBuilder allocBuilder(op); - auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands); - return alloc; + return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands); } Value InsertAlloc(Location loc, OpResult result, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast<ShapedType>(); + auto result_type = result.getType().dyn_cast<RankedTensorType>(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects statically shaped results"; @@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> { buffer_args.push_back( InsertAlloc(op->getLoc(), result.value(), &rewriter)); } else { - SmallVector<Value, 1> results_shape; auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); if (!shape_type_op) return failure(); - if (failed( - shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) - return failure(); + + SmallVector<Value, 1> results_shape; + auto status = + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); + if (failed(status)) return failure(); buffer_args.push_back(InsertDynamicAllocAndDealloc( op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); - rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size())); + rewriter.replaceOp( + op, llvm::makeArrayRef(buffer_args).drop_front(operands.size())); return success(); } };