From e8d84bd57a45dce49cebd5afe09875425813695f Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon@google.com>
Date: Mon, 5 Oct 2020 03:54:51 -0700
Subject: [PATCH] [HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo
 pass.

PiperOrigin-RevId: 335385243
Change-Id: I29682df1f6753900554fd018f8d40464a7ba41fc
---
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc   | 31 ++++++++-----------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index a808608888c..3485afffc9f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                    Value shape_operand,
                                    ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects ranked results";
@@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());
 
-  Operation* op = result.getDefiningOp();
-
   // Extract the required element out of the vector.
   SmallVector<Value, 4> dynamic_operands;
   for (auto shape_element : llvm::enumerate(result_type.getShape())) {
     if (shape_element.value() != ShapedType::kDynamicSize) continue;
-    Value index = rewriter->create<ConstantOp>(
-        loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
-                                      shape_element.index()));
-    Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
-                                                             ValueRange{index});
+    Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
+    Value alloc_operand =
+        rewriter->create<ExtractElementOp>(loc, shape_operand, index);
     if (!alloc_operand.getType().isIndex()) {
       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                     rewriter->getIndexType());
@@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
     dynamic_operands.push_back(alloc_operand);
   }
 
-  // Insert in front of op to ensure sizes are available.
-  OpBuilder allocBuilder(op);
-  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
-  return alloc;
+  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
 }
 
 Value InsertAlloc(Location loc, OpResult result,
                   ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type || !result_type.hasStaticShape()) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects statically shaped results";
@@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
         buffer_args.push_back(
             InsertAlloc(op->getLoc(), result.value(), &rewriter));
       } else {
-        SmallVector<Value, 1> results_shape;
         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
         if (!shape_type_op) return failure();
-        if (failed(
-                shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
-          return failure();
+
+        SmallVector<Value, 1> results_shape;
+        auto status =
+            shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
+        if (failed(status)) return failure();
         buffer_args.push_back(InsertDynamicAllocAndDealloc(
             op->getLoc(), result.value(), results_shape.front(), &rewriter));
       }
     }
     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                 buffer_args, op->getAttrs());
-    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    rewriter.replaceOp(
+        op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
     return success();
   }
 };