From 0406aa2fcc0d1d25972e2af9ead15e3b5ab56d11 Mon Sep 17 00:00:00 2001
From: Tim Shen <timshen@google.com>
Date: Wed, 7 Oct 2020 15:09:35 -0700
Subject: [PATCH] [XLA/GPU] Factor out the logic that maps MLIR op operands to
 kernel parameter slices, so it can be shared by all emitters.
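
This moves the per-operand slice construction (buffer slice, written
flag, and shape) out of EmitSortFromMlir into a shared helper,
GetMlirBufferSlices, so that other MLIR-based emitters in
ir_emitter_unnested.cc can reuse it. The helper queries the op's
mlir::MemoryEffectOpInterface to decide whether each operand buffer is
written.

As a rough sketch (not part of this change; the surrounding emitter
method and op handling are illustrative), a future MLIR-based emitter
could obtain its kernel parameter slices like this:

    // Hypothetical caller inside another Emit*FromMlir method.
    absl::Span<const BufferAllocation> allocations(
        ir_emitter_context_->buffer_assignment().Allocations());
    mlir::Operation* op = mlir_input.op;
    // One MlirBufferSlice per operand: buffer_slice, written, shape.
    TF_ASSIGN_OR_RETURN(
        std::vector<MlirBufferSlice> slices,
        GetMlirBufferSlices(op, op->getOperands(), allocations));
    // The slices can then be passed to BuildKernelThunkForMlir, as
    // EmitSortFromMlir does below.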

PiperOrigin-RevId: 335962025
Change-Id: Ie7d462bcf533d62bfdc5ddfb91666f20e26dd03b
---
 .../xla/service/gpu/ir_emitter_unnested.cc    | 88 ++++++++++---------
 .../xla/service/gpu/ir_emitter_unnested.h     |  2 +-
 2 files changed, 49 insertions(+), 41 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b0c221dc954..51bee21df4e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -202,6 +202,31 @@ StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
       "StaticMemRefCastOp(ViewOp(arg))");
 }
 
+StatusOr<std::vector<MlirBufferSlice>> GetMlirBufferSlices(
+    mlir::Operation* op, mlir::OperandRange operands,
+    absl::Span<const BufferAllocation> allocations) {
+  const auto buffer_is_written = [op](mlir::Value operand) {
+    llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
+    mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
+                                                                    effects);
+    return absl::c_any_of(
+        effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
+          return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
+        });
+  };
+
+  std::vector<MlirBufferSlice> slices;
+  for (mlir::Value operand : operands) {
+    slices.emplace_back();
+    auto& slice = slices.back();
+    TF_ASSIGN_OR_RETURN(slice.buffer_slice,
+                        GetAllocationSliceForMlir(operand, allocations));
+    slice.written = buffer_is_written(operand);
+    slice.shape = TypeToShape(operand.getType());
+  }
+  return slices;
+}
+
 }  // namespace
 
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -1371,47 +1396,30 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
   return EmitSortFromMlir(result);
 }
 
-Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
+Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   absl::Span<const BufferAllocation> allocations(
       ir_emitter_context_->buffer_assignment().Allocations());
-  auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(input.op);
+  auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
   std::string name = mlir::GetNameFromLoc(sort_op.getLoc());
-
-  int operand_count = sort_op.operands().size();
-  std::vector<xla::Shape> operand_shapes(operand_count);
-  std::vector<MlirBufferSlice> slices;
-  std::vector<xla::Shape> output_shapes(sort_op.output().size());
-
-  for (int i = 0; i < operand_count; i++) {
-    operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType());
-  }
-
-  // Craft n + 1 slices, where the first n are output parameters, and the last
-  // is the on-device tuple storage. We don't need n operands because sorting
-  // kernels are always in-place.
-  for (int i = 0; i < operand_count; i++) {
-    output_shapes[i] = TypeToShape(sort_op.output()[i].getType());
-    MlirBufferSlice slice;
-    TF_ASSIGN_OR_RETURN(
-        slice.buffer_slice,
-        GetAllocationSliceForMlir(sort_op.output()[i], allocations));
-    slice.written = true;
-    slice.shape = operand_shapes[i];
-    slices.push_back(slice);
-  }
-  slices.push_back(input.extra_slice);
+  TF_ASSIGN_OR_RETURN(
+      std::vector<MlirBufferSlice> operands,
+      GetMlirBufferSlices(sort_op, sort_op.operands(), allocations));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<MlirBufferSlice> outputs,
+      GetMlirBufferSlices(sort_op, sort_op.output(), allocations));
+  outputs.push_back(mlir_input.extra_slice);
 
   std::vector<std::unique_ptr<Thunk>> thunks;
 
-  Shape keys_shape = operand_shapes[0];
+  Shape keys_shape = operands[0].shape;
   int64 dimension_to_sort = sort_op.dimension();
-  for (int64 i = 0; i < operand_count; ++i) {
+  for (int64 i = 0; i < operands.size(); ++i) {
     // We assume that the layout of all involved operands and outputs is the
     // same.
     TF_RET_CHECK(
-        LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i]));
+        LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape));
     TF_RET_CHECK(
-        LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i]));
+        LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape));
 
     // If possible, we share buffers. If that is not possible, we need to copy
     // the values, because the emitter does the sorting in-place.
@@ -1429,7 +1437,7 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
           Thunk::ThunkInfo(),
           /*source_address=*/source_address,
           /*destination_buffer=*/destination_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i])));
+          /*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape)));
     }
   }
 
@@ -1499,10 +1507,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
   // we have not enough threads, or not enough shared memory. Also it does not
   // give a speedup if the tile size is < 128.
   int64 total_shared_memory_needed = 0;
-  for (int64 i = 0; i < operand_count; ++i) {
+  for (int64 i = 0; i < operands.size(); ++i) {
     total_shared_memory_needed +=
         kTileSize *
-        ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type());
+        ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type());
   }
   bool no_tiling =
       kTileSize < 128 ||
@@ -1533,15 +1541,15 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
           absl::StrAppendFormat(out, "0x%x", xor_mask);
         }));
     thunks.push_back(
-        BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), slices, &ir_arrays));
+        BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays));
     LaunchDimensions launch_dimensions = xor_masks.size() > 1
                                              ? tiled_launch_dimensions
                                              : standard_launch_dimensions;
     UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
                            ir_emitter_context_->llvm_module());
     std::vector<IrArray> values_arrays;
-    values_arrays.reserve(operand_count);
-    for (int64 i = 0; i < operand_count; ++i) {
+    values_arrays.reserve(operands.size());
+    for (int64 i = 0; i < operands.size(); ++i) {
       values_arrays.push_back(ir_arrays[i]);
     }
     TF_ASSIGN_OR_RETURN(
@@ -1583,14 +1591,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) {
   VLOG(2) << absl::StreamFormat(
       "%s requires %d thunks (including any D2D copies)", name, thunks.size());
 
-  AddThunkToThunkSequence(
-      absl::make_unique<SequentialThunk>(input.thunk_info, std::move(thunks)));
-  if (operand_count > 1) {
+  AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+      mlir_input.thunk_info, std::move(thunks)));
+  if (operands.size() > 1) {
     // Emit the tuple as part of the last stage of sorting.
     // We are currently in the block sorted.in_bounds.after.
     b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
     llvm_ir::EmitTuple(
-        ir_arrays[operand_count],
+        ir_arrays.back(),
         absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
   }
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index a317aac16ec..5cc5e206167 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -160,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
   Status HandleSort(HloInstruction* sort) override;
-  Status EmitSortFromMlir(MlirEmitterInput input);
+  Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleAllReduce(HloInstruction* crs) override;