From b266b468252baa1b2a8348c86ffec071fc90fa95 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Wed, 27 May 2020 11:26:31 -0700
Subject: [PATCH] [XLA:GPU] Use the generic implementation for elemental reduce

The generic version used in fusions didn't support variadic reduction on GPU
(it did on CPU), so tie up some loose ends and use the generic version.

PiperOrigin-RevId: 313428251
Change-Id: Ide547280b0fcf04a99a51b721d8ca860c9da6305
---
 .../xla/service/gpu/elemental_ir_emitter.h    |   9 +-
 .../compiler/xla/service/gpu/ir_emitter.cc    | 136 +++---------------
 .../compiler/xla/service/gpu/ir_emitter.h     |   3 +-
 3 files changed, 25 insertions(+), 123 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 3c4e9f7c1e6..a3056b1ddad 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
  public:
   // A NestedComputer computes an element of the output of the given computation
   // given a Span of its input elements.
-  using NestedComputer = std::function<StatusOr<llvm::Value*>(
+  using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
       const HloComputation&, absl::Span<llvm::Value* const>)>;
 
   GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
@@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
       absl::string_view) override {
-    // TODO(b/118332391): Supported variadic return values.
-    auto result = compute_nested_(callee, parameters);
-    if (!result.ok()) {
-      return result.status();
-    }
-    return std::vector<llvm::Value*>{result.ValueOrDie()};
+    return compute_nested_(callee, parameters);
   }
 
   llvm::Value* EmitThreadId() override;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 744cd7b56bf..aa8a6215cc7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleReduce(HloInstruction* instr) {
-  const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
-  const Shape& out_shape = reduce->shape();
-  bool returns_tuple = !out_shape.IsArray();
-  int accumulators_count = 1;
-  if (returns_tuple) {
-    CHECK(out_shape.IsTuple());
-    accumulators_count = out_shape.tuple_shapes_size();
-  }
-
-  auto arg = reduce->operand(0);
-  absl::Span<const int64> dimensions(reduce->dimensions());
-  HloComputation* function = reduce->to_apply();
-  return EmitTargetElementLoop(
-      *reduce,
-      [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        std::vector<llvm::Value*> accumulator_addrs;
-        std::vector<llvm::Type*> accumulator_types;
-
-        // Initialize accumulators with initial values.
-        for (int i = 0; i < accumulators_count; i++) {
-          auto init_value = reduce->init_values()[i];
-          const Shape& element_shape =
-              returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
-          PrimitiveType accumulator_type = element_shape.element_type();
-          llvm::Type* accumulator_llvm_type =
-              llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
-          llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
-          Store(Load(GetBasePointer(*init_value)), accumulator_addr);
-          accumulator_addrs.push_back(accumulator_addr);
-          accumulator_types.push_back(accumulator_llvm_type);
-        }
-
-        // The enclosing loops go over all the target elements. Now we have to
-        // compute the actual target element. For this, we build a new loop nest
-        // to iterate over all the reduction dimensions in the argument.
-        // AddLoopsForShapeOnDimensions will return an Index where induction
-        // Value*s are placed for each dimension in dimensions, and all the rest
-        // are nullptrs.
-        llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
-        std::vector<llvm::Value*> input_multi_index =
-            loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
-                                               "reduction_dim");
-
-        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-        // Build a full index for the input argument, using reduced_dims_index
-        // as the base. In reduced_dims_index only the reduction dimensions are
-        // filled in. We fill in the rest of the dimensions with induction
-        // Value*s taken from 'index' which iterates over the target array.
-        // See the high-level description in the XLA documentation for details.
-        llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
-        for (auto& i : input_multi_index) {
-          if (i == nullptr) {
-            i = *it++;
-          }
-        }
-        CHECK(index.end() == it);
-
-        // Apply the reduction function to the loaded value.
-        llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
-                                            b_.getInt64Ty());
-        std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
-                                                     accumulator_addrs.end());
-        for (int i = 0; i < accumulators_count; i++) {
-          llvm::Value* input_address =
-              GetIrArray(*reduce->operand(i), *reduce)
-                  .EmitArrayElementAddress(input_index, &b_);
-          reduction_operands.push_back(input_address);
-        }
-
-        llvm::Value* ret_argument;
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          ret_argument = accumulator_addrs[0];
-        } else {
-          const Shape& return_shape = function->root_instruction()->shape();
-
-          llvm::Type* return_value_buffer_type =
-              llvm_ir::ShapeToIrType(return_shape, module_);
-          ret_argument = Alloca(return_value_buffer_type);
-          llvm_ir::IrArray tuple_array(ret_argument, return_shape);
-          EmitTuple(tuple_array, accumulator_addrs, &b_);
-        }
-
-        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-            *function, reduction_operands, ret_argument));
-
-        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          return Load(accumulator_addrs[0]);
-        } else {
-          // Emit a struct for the LoopEmitter dealing with multi-output
-          // fusion.
-          llvm::Value* returned_structure = llvm::UndefValue::get(
-              llvm::StructType::get(b_.getContext(), accumulator_types));
-          for (int i = 0; i < accumulators_count; i++) {
-            llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
-            returned_structure =
-                b_.CreateInsertValue(returned_structure, accumulator_value, i);
-          }
-          return returned_structure;
-        }
-      });
-}
-
 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
   // kFusion for library calls should be handled by
   // IrEmitterUnnested::HandleFusion.
@@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
       "to a cudnn CustomCall using CudnnBatchNormRewriter.");
 }
 
-StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
+StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
     const HloComputation& computation,
     absl::Span<llvm::Value* const> parameter_elements) {
+  const Shape& return_shape = computation.root_instruction()->shape();
   llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(
-          computation.root_instruction()->shape().element_type(), module_),
-      "return_buffer", &b_);
+      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
   std::vector<llvm::Value*> parameter_buffers;
   for (llvm::Value* parameter_element : parameter_elements) {
     parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
         parameter_element->getType(), "parameter_buffer", &b_));
     Store(parameter_element, parameter_buffers.back());
   }
+
+  std::vector<llvm::Value*> allocas_for_returned_scalars;
+  if (!return_shape.IsTuple()) {
+    allocas_for_returned_scalars.push_back(return_buffer);
+  } else {
+    allocas_for_returned_scalars =
+        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
+    llvm_ir::IrArray tuple_array(return_buffer, return_shape);
+
+    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
+  }
+
   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                  return_buffer));
-  return Load(return_buffer);
+
+  std::vector<llvm::Value*> returned_scalars;
+  returned_scalars.reserve(allocas_for_returned_scalars.size());
+  for (llvm::Value* addr : allocas_for_returned_scalars) {
+    returned_scalars.push_back(Load(addr));
+  }
+  return returned_scalars;
 }
 
 std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index e0fe454dcfe..93712961ea2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleRecv(HloInstruction* recv) override;
   Status HandleRecvDone(HloInstruction* recv_done) override;
   Status HandleParameter(HloInstruction* parameter) override;
-  Status HandleReduce(HloInstruction* reduce) override;
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
@@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
       const llvm_ir::IrArray::Index& compare_keys_index,
       const llvm_ir::IrArray& keys_array);
 
-  StatusOr<llvm::Value*> ComputeNestedElement(
+  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);
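
Note on the commit message: the "variadic reduction" it refers to is a reduce with more
than one input operand whose reducer returns a tuple, so the reduce itself has a
tuple-shaped result. The HLO text below is an illustrative sketch only, not part of this
patch; the module, computation, and parameter names are made up. It shows an argmax-style
reduce that carries a value and an index together. After this change, such a reduce can be
emitted through the generic elemental path on GPU as well, since ComputeNestedElement now
allocates one buffer per leaf of the reducer's tuple-shaped root and returns one
llvm::Value* per leaf instead of a single scalar.

  // Illustrative example only; not part of this patch.
  HloModule variadic_reduce_example

  %max_with_index {
    %acc_val = f32[] parameter(0)
    %acc_idx = s32[] parameter(1)
    %val = f32[] parameter(2)
    %idx = s32[] parameter(3)
    %keep_acc = pred[] compare(%acc_val, %val), direction=GE
    %new_val = f32[] select(%keep_acc, %acc_val, %val)
    %new_idx = s32[] select(%keep_acc, %acc_idx, %idx)
    ROOT %result = (f32[], s32[]) tuple(%new_val, %new_idx)
  }

  ENTRY %main {
    %values = f32[16,32] parameter(0)
    %indices = s32[16,32] parameter(1)
    %init_val = f32[] constant(-inf)
    %init_idx = s32[] constant(0)
    ROOT %argmax = (f32[16], s32[16]) reduce(%values, %indices, %init_val, %init_idx),
        dimensions={1}, to_apply=%max_with_index
  }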