From b651a2cb5aa46b64422ecf2078403a2436e3868b Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Thu, 28 Mar 2019 14:09:22 -0700
Subject: [PATCH] Variadic reduce implementation on GPU.

Implements a slower, non-vectorized version. A faster, vectorized version
remains to be done.

PiperOrigin-RevId: 240849148
---
 .../xla/service/gpu/hlo_to_ir_bindings.cc    | 14 ++-
 .../xla/service/gpu/instruction_fusion.cc    |  5 ++
 .../compiler/xla/service/gpu/ir_emitter.cc   | 85 +++++++++++++++----
 .../xla/service/gpu/ir_emitter_nested.cc     | 83 +++++++++++++-----
 .../xla/service/gpu/ir_emitter_nested.h      |  8 +-
 .../xla/service/gpu/ir_emitter_unnested.cc   | 20 ++---
 tensorflow/compiler/xla/tests/reduce_test.cc |  8 +-
 7 files changed, 165 insertions(+), 58 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 3c50c2b1d8e..6e414bd7a4d 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -38,17 +38,21 @@ using absl::StrCat;
 void HloToIrBindings::EmitBasePointersForHlos(
     absl::Span<const HloInstruction* const> io_hlos,
     absl::Span<const HloInstruction* const> non_io_hlos) {
-  // I/O HLOs are bound to the arguments of the current IR function. I.e.,
+  // I/O HLOs are bound to the arguments of the current IR function,
+  // *excluding* the output argument, which is added to non-I/O HLOs.
+  // I.e.,
   //
-  // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
+  // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg, temp_buffer_base) {
   llvm::Function* function = b_->GetInsertBlock()->getParent();
-  CHECK_EQ(io_hlos.size() + 1, function->arg_size());
+  CHECK_EQ(io_hlos.size() + 2, function->arg_size());
 
   // An HLO can have duplicated operands. This data structure remembers which
   // operand HLOs are already bound to avoid rebinding the same HLO.
   absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
   auto arg_iter = function->arg_begin();
   for (const HloInstruction* io_hlo : io_hlos) {
+    CHECK(!absl::c_count(non_io_hlos, io_hlo))
+        << "IO HLOs and non-IO HLOs should be disjoint";
     if (!already_bound_for_this_function.contains(io_hlo)) {
       if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
         BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
@@ -60,6 +64,10 @@ void HloToIrBindings::EmitBasePointersForHlos(
     ++arg_iter;
   }
 
+  // Name and skip the output parameter.
+  arg_iter->setName("output_arg");
+  ++arg_iter;
+
   temp_buffer_base_ = &*arg_iter;
   temp_buffer_base_->setName("temp_buffer");
 
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 38fd2f64ef4..1340340124e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -256,6 +256,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     return false;
   }
   auto producer = consumer->operand(operand_index);
+
+  // Don't fuse variadic reduce.
+  if (consumer->opcode() == HloOpcode::kReduce && consumer->shape().IsTuple()) {
+    return false;
+  }
   // The following checks are potentially expensive.
   if (FusionWouldBeTooLarge(consumer, producer)) {
     return false;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index f04e8241e57..a3fb1ce7307 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -24,6 +24,7 @@ limitations under the License.
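+// A variadic reduce applies one shared reducer computation pointwise to N
+// input arrays and N init values, and yields an N-tuple. An illustrative
+// HLO sketch (names and shapes are made up, not part of this change):
+//
+//   ROOT out = (f32[2], f32[2]) reduce(a, b, init_a, init_b),
+//       dimensions={0}, to_apply=reducer
+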
#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -32,7 +33,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -157,8 +160,7 @@ Status IrEmitter::EmitCallToNestedComputation( if (emitted_function == nullptr) { IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation, ir_emitter_context_); - TF_RETURN_IF_ERROR( - nested_computation.root_instruction()->Accept(&ir_emitter_nested)); + TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation()); emitted_function = ir_emitter_nested.GetEmittedFunction(); } @@ -661,23 +663,38 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/118332391): Support variadic reduce. - if (!reduce->shape().IsArray()) { - return Unimplemented("Variadic reduce is not supported on GPU"); +Status IrEmitter::HandleReduce(HloInstruction* instr) { + const HloReduceInstruction* reduce = Cast(instr); + const Shape& out_shape = reduce->shape(); + bool returns_tuple = !out_shape.IsArray(); + int accumulators_count = 1; + if (returns_tuple) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); } + auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - // Initialize an accumulator with init_value. - llvm::AllocaInst* accumulator_addr = - Alloca(llvm_ir::PrimitiveTypeToIrType( - reduce->shape().element_type(), module_)); - Store(Load(GetBasePointer(*init_value)), accumulator_addr); + std::vector accumulator_addrs; + std::vector accumulator_types; + + // Initialize accumulators with initial values. + for (int i = 0; i < accumulators_count; i++) { + auto init_value = reduce->init_values()[i]; + const Shape& element_shape = + returns_tuple ? out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + accumulator_types.push_back(accumulator_llvm_type); + } // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -709,13 +726,49 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. 
         llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                             b_.getInt64Ty());
-        llvm::Value* input_address =
-            GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
+        std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
+                                                     accumulator_addrs.end());
+        for (int i = 0; i < accumulators_count; i++) {
+          llvm::Value* input_address =
+              GetIrArray(*reduce->operand(i), *reduce)
+                  .EmitArrayElementAddress(input_index, &b_);
+          reduction_operands.push_back(input_address);
+        }
+
+        llvm::Value* ret_argument;
+        if (!returns_tuple) {
+          CHECK_EQ(accumulator_addrs.size(), 1);
+          ret_argument = accumulator_addrs[0];
+        } else {
+          const Shape& return_shape = function->root_instruction()->shape();
+
+          llvm::Type* return_value_buffer_type =
+              llvm_ir::ShapeToIrType(return_shape, module_);
+          ret_argument = Alloca(return_value_buffer_type);
+          llvm_ir::IrArray tuple_array(ret_argument, return_shape);
+          EmitTuple(tuple_array, accumulator_addrs, &b_);
+        }
+
         TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-            *function, {accumulator_addr, input_address}, accumulator_addr));
+            *function, reduction_operands, ret_argument));
 
         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-        return Load(accumulator_addr);
+
+        if (!returns_tuple) {
+          CHECK_EQ(accumulator_addrs.size(), 1);
+          return Load(accumulator_addrs[0]);
+        } else {
+          // Emit a struct for the LoopEmitter dealing with multi-output
+          // fusion.
+          llvm::Value* returned_structure = llvm::UndefValue::get(
+              llvm::StructType::get(b_.getContext(), accumulator_types));
+          for (int i = 0; i < accumulators_count; i++) {
+            llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+            returned_structure =
+                b_.CreateInsertValue(returned_structure, accumulator_value, i);
+          }
+          return returned_structure;
+        }
       });
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 8c02416eef4..bd8d72af581 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -38,20 +38,18 @@ namespace gpu {
 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
                                  const HloComputation& nested_computation,
                                  IrEmitterContext* ir_emitter_context)
-    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
-  std::vector<const HloInstruction*> io_hlos;
-  emitted_function_ =
-      EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
-}
+    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
+      nested_computation_(nested_computation) {}
 
-llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
-    const HloComputation& nested_computation,
-    std::vector<const HloInstruction*>* io_hlos) {
+// A nested function serves the same purpose on a GPU as a thread-local
+// function on a CPU.
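+//
+// The emitted function takes one pointer argument per computation parameter,
+// then an out parameter the root value is copied into, then the temp buffer
+// base pointer. For a binary f32 reducer the signature is roughly (an
+// illustrative sketch, not the literal IR):
+//
+//   void reducer(float* acc, float* x, float* out, i8* temp_buffer_base)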
+Status IrEmitterNested::CodegenNestedComputation() {
+  std::vector<const HloInstruction*> io_hlos;
   std::vector<llvm::Type*> argument_types;
   std::vector<int64> argument_dereferenceable_bytes;
   for (const HloInstruction* param :
-       nested_computation.parameter_instructions()) {
-    io_hlos->push_back(param);
+       nested_computation_.parameter_instructions()) {
+    io_hlos.push_back(param);
     const Shape& param_shape = param->shape();
     argument_types.push_back(
         llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
@@ -59,9 +57,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
         llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
     argument_dereferenceable_bytes.push_back(param_size);
   }
+
+  const HloInstruction* root = nested_computation_.root_instruction();
   {
-    const HloInstruction* root = nested_computation.root_instruction();
-    io_hlos->push_back(root);
     const Shape& root_shape = root->shape();
     argument_types.push_back(
         llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
@@ -79,8 +77,8 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
       llvm::GlobalValue::InternalLinkage,  // The linkage type.
       ir_emitter_context_->name_uniquer()->GetUniqueName(
           llvm_ir::SanitizeFunctionName(
-              nested_computation.name())),  // The name of the function.
-      ir_emitter_context_->llvm_module());  // The parent LLVM module.
+              nested_computation_.name())),  // The name of the function.
+      ir_emitter_context_->llvm_module());   // The parent LLVM module.
   for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
        ++arg_no) {
     int64 arg_size = argument_dereferenceable_bytes[arg_no];
@@ -96,17 +94,62 @@
       llvm::BasicBlock::Create(function->getContext(), "entry", function);
   // Emit a "return void" at entry_bb's end, and sets the insert point before
   // that return instruction.
-  b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb));
+  llvm::ReturnInst* ret_instr =
+      llvm::ReturnInst::Create(function->getContext(), entry_bb);
+  b_.SetInsertPoint(ret_instr);
 
   std::vector<const HloInstruction*> non_io_hlos;
+  non_io_hlos.push_back(root);
   for (const auto* hlo : nested_computation_.instructions()) {
     if (hlo->opcode() != HloOpcode::kParameter &&
-        hlo != nested_computation.root_instruction()) {
+        hlo != nested_computation_.root_instruction()) {
       non_io_hlos.push_back(hlo);
     }
   }
-  bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
-  return function;
+  bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
+
+  TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
+  b_.SetInsertPoint(ret_instr);
+
+  // Function epilogue: copy the output value back.
+  {
+    // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
+    const HloInstruction* root_instruction =
+        nested_computation_.root_instruction();
+    llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
+    const Shape& return_shape = root_instruction->shape();
+
+    // The second-to-last argument is the out parameter.
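+    // (The last argument is the temp buffer base; EmitBasePointersForHlos
+    // names it "temp_buffer".) A scalar root is stored straight into the out
+    // parameter; a tuple-shaped root, as produced by a variadic reducer, is
+    // copied over leaf by leaf below.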
+    llvm::Argument* out_parameter = std::prev(function->arg_end(), 2);
+
+    if (ShapeUtil::IsScalar(return_shape)) {
+      llvm::Value* ret_value = Load(root_value, "load_ret_value");
+      Store(ret_value,
+            BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"),
+            "store_ret_value");
+    } else {
+      CHECK(return_shape.IsTuple());
+      llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
+      llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
+      llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);
+
+      for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
+        const Shape& element_shape = return_shape.tuple_shapes(i);
+        llvm::Value* destination =
+            llvm_ir::EmitGetTupleElement(element_shape,
+                                         /*index=*/i,
+                                         /*alignment=*/1, tuple_ptr, &b_);
+        llvm::Value* source =
+            llvm_ir::EmitGetTupleElement(element_shape,
+                                         /*index=*/i,
+                                         /*alignment=*/1, root_value, &b_);
+        Store(Load(source), destination);
+      }
+    }
+  }
+  b_.SetInsertPoint(ret_instr);
+  emitted_function_ = function;
+  return Status::OK();
 }
 
 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
@@ -118,7 +161,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
     const llvm_ir::ElementGenerator& element_generator) {
   // For MOF we give the loop emitter an array for every output it should
   // generate.
-  if (hlo.IsMultiOutputFusion()) {
+  if (hlo.shape().IsTuple()) {
     std::vector<llvm_ir::IrArray> target_arrays = ConstructIrArrayForOutputs(hlo);
     TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
index ca11cf2c182..ce825851bcc 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h
@@ -58,11 +58,11 @@ class IrEmitterNested : public IrEmitter {
       const HloInstruction& hlo,
       const llvm_ir::ElementGenerator& body_emitter) override;
 
- private:
-  llvm::Function* EmitBasePointersForNestedComputation(
-      const HloComputation& nested_computation,
-      std::vector<const HloInstruction*>* io_hlos);
+  // Generate the code for the computation passed in the constructor.
+  Status CodegenNestedComputation();
 
+ private:
+  const HloComputation& nested_computation_;
   llvm::Function* emitted_function_;
 };
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f1ab8fe5d97..317ca2184ff 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -632,8 +632,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   // a 1D array. The specialized version requires a initializer thunk that
   // initializes the output array to the initial value of the reduce.
   if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) {
-    // TODO(b/118332391): Support variadic reduce.
-    return Unimplemented("Variadic reduce is not supported on GPU");
+    // TODO(b/129089333): Support tiled vectorized variadic reduce.
+    return Unimplemented(
+        "Vectorized variadic reduce is not supported on GPU");
   }
   return EmitReductionToVector(fusion);
 }
@@ -722,11 +723,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
 }
 
 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
-  // TODO(b/118332391): Support multi-output reduce.
- if (!reduce->shape().IsArray()) { - return Unimplemented("Multi-output reduce is not supported on GPU"); - } - if (IsReductionToVector(*reduce)) { + if (IsReductionToVector(*reduce) && reduce->shape().IsArray()) { return EmitReductionToVector(reduce); } @@ -2179,9 +2176,10 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( int unroll_factor = thunk->unroll_factor(); VLOG(3) << bindings_.ToString(); - const Shape& element_shape = hlo.IsMultiOutputFusion() - ? ShapeUtil::GetSubshape(hlo.shape(), {0}) - : hlo.shape(); + bool multi_output = hlo.shape().IsTuple(); + + const Shape& element_shape = + multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); VLOG(3) << "EmitTargetElementLoopInThunk " << ShapeUtil::HumanStringWithLayout(hlo.shape()) << " for unroll_factor " << unroll_factor; @@ -2189,7 +2187,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( element_shape, ir_emitter_context_->device_description(), unroll_factor); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); - if (!hlo.IsMultiOutputFusion()) { + if (!multi_output) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &b_, unroll_factor) .EmitLoop( diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 18c4e315033..1fcb2120072 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -1038,7 +1038,7 @@ XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) { class VariadicReduceTest : public HloTestBase {}; -XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) { +XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) { absl::string_view hlo_string = R"( HloModule Reduce_R3x2_to_R1x2_simple @@ -1066,7 +1066,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) { +XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) { absl::string_view hlo_string = R"( HloModule Reduce_R3x2_to_R1x2_simple @@ -1094,7 +1094,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) { +XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) { absl::string_view hlo_string = R"( HloModule Reduce_R1x2_to_R0x2_simple @@ -1122,7 +1122,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); } -XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) { +XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) { absl::string_view hlo_string = R"( HloModule Reduce_R1x2_to_R0x2_argmax