From 803b0f51519e48fa05331951cd90eb4773ac0bdd Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Mon, 18 Mar 2019 11:12:23 -0700
Subject: [PATCH] Variadic reduce implementation on CPU

Implements variadic reduce on the CPU backend.

Before this change, thread-local functions could only return scalars. The
biggest part of this change is allowing thread-local functions to return
tuples of scalars, which required changes to function generation, allocation
of space for the returned values on the caller's side, and generating the
function epilogue.

PiperOrigin-RevId: 239022395
---
 .../xla/service/cpu/elemental_ir_emitter.cc   |  16 +-
 .../compiler/xla/service/cpu/ir_emitter.cc    | 233 +++++++++++++-----
 .../compiler/xla/service/cpu/ir_emitter.h     |  24 +-
 .../xla/service/llvm_ir/loop_emitter.cc       |   6 +-
 .../compiler/xla/service/llvm_ir/tuple_ops.cc |  26 ++
 .../compiler/xla/service/llvm_ir/tuple_ops.h  |   6 +
 tensorflow/compiler/xla/tests/BUILD           |   1 +
 tensorflow/compiler/xla/tests/reduce_test.cc  |  96 ++++++++
 8 files changed, 340 insertions(+), 68 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index fb021f277b0..e21ca01c803 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -135,10 +136,19 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
       };
     case HloOpcode::kReduce:
       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
+        std::vector<llvm_ir::ElementGenerator> input_generators;
+        for (const HloInstruction* instr : reduce_instr->inputs()) {
+          input_generators.push_back(operand_to_generator.at(instr));
+        }
+
+        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
+        for (const HloInstruction* instr : reduce_instr->init_values()) {
+          initial_value_generators.push_back(operand_to_generator.at(instr));
+        }
         return ir_emitter_->EmitElementalReduce(
-            Cast<HloReduceInstruction>(hlo),
-            operand_to_generator.at(hlo->operand(0)),
-            operand_to_generator.at(hlo->operand(1)), index);
+            reduce_instr, std::move(input_generators),
+            std::move(initial_value_generators), index);
       };
     default:
       return ElementalIrEmitter::MakeElementGenerator(hlo,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 55b6a3558fc..ee1431203d2 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -58,8 +58,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -106,6 +108,42 @@ IrEmitter::IrEmitter( TF_CHECK_OK(s) << "Should have failed buffer assignment."; } +void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) { + llvm::Argument* out_parameter = compute_function_->result_arg(); + llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction()); + const Shape& return_shape = computation->root_instruction()->shape(); + + if (ShapeUtil::IsScalar(return_shape)) { + llvm::Value* ret_value = + Load(root_value.GetBasePointer(), "load_ret_value"); + Store(ret_value, + BitCast(out_parameter, root_value.GetBasePointer()->getType())); + } else { + CHECK(return_shape.IsTuple()); + + llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_); + llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo(); + llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue); + + for (int i = 0; i < return_shape.tuple_shapes_size(); i++) { + const Shape& element_shape = return_shape.tuple_shapes(i); + llvm::Value* destination = llvm_ir::EmitGetTupleElement( + element_shape, + /*index=*/i, + /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue, + &b_); + + llvm::Value* source = llvm_ir::EmitGetTupleElement( + element_shape, + /*index=*/i, + /*alignment=*/MinimumAlignmentForShape(element_shape), + root_value.GetBasePointer(), &b_); + + Store(Load(source), destination); + } + } +} + StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, @@ -143,6 +181,16 @@ StatusOr IrEmitter::EmitComputation( InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. + + // Function epilogue: copying the value over to either the return register, + // or values pointing from the return register. + const BufferAllocation* root_allocation = + computation_root_allocation_.allocation(); + if (root_allocation && root_allocation->is_thread_local()) { + EmitThreadLocalFunctionEpilogue(computation); + } + + // Destructor for compute_function_ emits the "ret void" instruction. 
   compute_function_.reset();
   computation_root_allocation_ = BufferAllocation::Slice();
   computation_parameter_allocations_.clear();
@@ -634,7 +682,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
 llvm::Value* IrEmitter::EmitElementalMap(
     const HloMapInstruction& map_instr,
     absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
-  return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
+  return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
+                                            elemental_operands, name);
 }
 
 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
@@ -716,7 +765,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
                                          b_.getInt64Ty());
     TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                         input_generator(input_index));
-    llvm::Value* result = EmitThreadLocalCall(
+    llvm::Value* result = EmitScalarReturningThreadLocalCall(
         *reduce_window->to_apply(), {Load(accumulator_address), input_value},
         "reducer_function");
     Store(result, accumulator_address);
@@ -868,7 +917,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
         llvm::Value* operand_address =
            operand_array.EmitArrayElementAddress(operand_index, &b_);
         llvm::Value* operand_element = Load(operand_address);
-        llvm::Value* result = EmitThreadLocalCall(
+        llvm::Value* result = EmitScalarReturningThreadLocalCall(
            *select_and_scatter->select(),
            {Load(selected_value_address), operand_element}, "select_function");
 
@@ -903,9 +952,9 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
         selected_multi_index, output_array.GetShape(), source_index.GetType());
     llvm::Value* output_value =
         output_array.EmitReadArrayElement(selected_index, &b_);
-    llvm::Value* scatter_value =
-        EmitThreadLocalCall(*select_and_scatter->scatter(),
-                            {output_value, source_value}, "scatter_function");
+    llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
+        *select_and_scatter->scatter(), {output_value, source_value},
+        "scatter_function");
     output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
 
     SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1665,6 +1714,11 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
     absl::Span<const int64> dimensions, HloComputation* function,
     string* failure_reason) {
+  if (!reduce->shape().IsArray()) {
+    *failure_reason = "vectorization of variadic reduce not implemented";
+    return false;
+  }
+
   if (!ReductionPreservesLayout(*reduce)) {
     return false;
   }
@@ -1813,21 +1867,39 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
 
 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
     const HloReduceInstruction* reduce,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::ElementGenerator& initial_value_generator,
+    std::vector<llvm_ir::ElementGenerator> input_generators,
+    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
     const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* arg = reduce->operand(0);
-  absl::Span<const int64> dimensions(reduce->dimensions());
+  const Shape& out_shape = reduce->shape();
+  bool is_variadic = !out_shape.IsArray();
+  int accumulators_count = 1;
+  if (is_variadic) {
+    CHECK(out_shape.IsTuple());
+    accumulators_count = out_shape.tuple_shapes_size();
+  }
 
-  // Initialize an accumulator with init_value.
-  PrimitiveType accumulator_type = reduce->shape().element_type();
-  llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
-      &b_, MinimumAlignmentForPrimitiveType(accumulator_type));
-  TF_ASSIGN_OR_RETURN(
-      llvm::Value* const init_value,
-      initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
-  Store(init_value, accumulator_addr);
+  absl::Span<const int64> reduced_dimensions(reduce->dimensions());
+
+  std::vector<llvm::Value*> accumulator_addrs;
+  std::vector<llvm::Type*> accumulator_types;
+  for (int i = 0; i < accumulators_count; i++) {
+    const Shape& element_shape =
+        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
+    PrimitiveType accumulator_type = element_shape.element_type();
+    llvm::Type* accumulator_llvm_type =
+        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
+    accumulator_types.push_back(accumulator_llvm_type);
+
+    // Initialize an accumulator with init_value.
+    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+        accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
+        MinimumAlignmentForPrimitiveType(accumulator_type));
+    TF_ASSIGN_OR_RETURN(
+        llvm::Value* const init_value,
+        initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
+    Store(init_value, accumulator_addr);
+    accumulator_addrs.push_back(accumulator_addr);
+  }
 
   // The enclosing loops go over all the target elements. Now we have to compute
   // the actual target element. For this, we build a new loop nest to iterate
@@ -1835,14 +1907,15 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
   // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
   // are placed for each dimension in dimensions, and all the rest are nullptrs.
   llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
+  const HloInstruction* arg = reduce->operand(0);
   std::vector<llvm::Value*> input_multi_index =
-      loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                          "reduction_dim");
 
   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
 
-  // Build a full index for the input argument, using reduced_dims_index as the
-  // base. In reduced_dims_index only the reduction dimensions are filled in. We
+  // Build a full index for the input argument, using input_multi_index as the
+  // base. In input_multi_index only the reduction dimensions are filled in. We
   // fill in the rest of the dimensions with induction Value*s taken from
   // 'index' which iterates over the target array. See the high-level
   // description in the XLA documentation for details.
@@ -1857,23 +1930,44 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                       b_.getInt64Ty());
 
-  // Apply the reduction function to the loaded value.
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
-                      input_generator(input_index));
-  llvm::Value* result = EmitThreadLocalCall(
-      *reduce->to_apply(), {Load(accumulator_addr), input_element},
-      "reduce_function");
-  Store(result, accumulator_addr);
+  std::vector<llvm::Value*> reduction_operands;
+  for (llvm::Value* accum : accumulator_addrs) {
+    llvm::Value* accum_value = Load(accum);
+    reduction_operands.push_back(accum_value);
+  }
+  for (int i = 0; i < accumulators_count; i++) {
+    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
+                        input_generators[i](input_index));
+    reduction_operands.push_back(input_element);
+  }
+
+  std::vector<llvm::Value*> results = EmitThreadLocalCall(
+      *reduce->to_apply(), reduction_operands, "reduce_function");
+
+  CHECK(results.size() == accumulators_count);
+  for (int i = 0; i < accumulators_count; i++) {
+    Store(results[i], accumulator_addrs[i]);
+  }
 
   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return Load(accumulator_addr);
+
+  if (is_variadic) {
+    // Emit a structure, as that is what the LoopEmitter expects.
+    llvm::Value* returned_structure = llvm::UndefValue::get(
+        llvm::StructType::get(b_.getContext(), accumulator_types));
+    for (int i = 0; i < accumulators_count; i++) {
+      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+      returned_structure =
+          b_.CreateInsertValue(returned_structure, accumulator_value, i);
+    }
+    return returned_structure;
+  } else {
+    CHECK_EQ(accumulator_addrs.size(), 1);
+    return Load(accumulator_addrs[0]);
+  }
 }
 
 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
-  // TODO(b/118333695): Support variadic reduce.
-  if (!reduce->shape().IsArray()) {
-    return Unimplemented("Variadic reduce is not supported on CPU");
-  }
   auto arg = reduce->mutable_operand(0);
   auto init_value = reduce->mutable_operand(1);
   absl::Span<const int64> dimensions(reduce->dimensions());
@@ -2848,15 +2942,6 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
     const BufferAllocation::Slice& slice, const Shape& target_shape) {
   const BufferAllocation& allocation = *slice.allocation();
   llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
-    if (slice == computation_root_allocation_) {
-      llvm::Argument* retval = compute_function_->result_arg();
-      llvm::AttrBuilder attr_builder;
-      attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
-      attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
-      retval->addAttrs(attr_builder);
-      return retval;
-    }
-
     auto param_it =
         computation_parameter_allocations_.find(slice.allocation()->index());
     if (param_it != computation_parameter_allocations_.end()) {
@@ -2966,7 +3051,8 @@ Status IrEmitter::EmitTargetElementLoop(
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
 
-  if (target_op->IsMultiOutputFusion()) {
+  if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
+                                 target_op->opcode() == HloOpcode::kReduce)) {
     // For multiple outputs fusion, we need to emit each operand and the root.
     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
     std::vector<llvm_ir::IrArray> output_arrays;
@@ -3048,19 +3134,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
 }
 
-llvm::Value* IrEmitter::EmitThreadLocalCall(
+llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
+    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+    absl::string_view name) {
+  std::vector<llvm::Value*> return_value =
+      EmitThreadLocalCall(callee, parameters, name);
+  CHECK_EQ(return_value.size(), 1);
+  return return_value[0];
+}
+
+std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
     absl::string_view name) {
   CHECK(absl::c_binary_search(thread_local_computations_, &callee));
-
   const Shape& return_shape = callee.root_instruction()->shape();
-
-  // Lifting this restriction to allow "small" arrays should be easy. Allowing
-  // larger arrays is difficult because we allocate the buffer for this return
-  // value on the stack.
-  CHECK(ShapeUtil::IsScalar(return_shape));
-
-  PrimitiveType return_type = return_shape.element_type();
+  bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
+  bool is_tuple_of_scalars_return =
+      return_shape.IsTuple() &&
+      absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
+        return ShapeUtil::IsScalar(shape);
+      });
+  CHECK(is_scalar_return || is_tuple_of_scalars_return);
 
   std::vector<llvm::Value*> parameter_addrs;
   for (llvm::Value* parameter : parameters) {
@@ -3071,10 +3165,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
     parameter_addrs.push_back(parameter_addr);
   }
 
+  llvm::Type* return_value_buffer_type =
+      llvm_ir::ShapeToIrType(return_shape, module_);
+  std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
+  int retval_alignment =
+      is_scalar_return
+          ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
+          : 0;
   llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(return_type, module_),
-      absl::StrCat(name, "_retval_addr"), &b_,
-      MinimumAlignmentForPrimitiveType(return_type));
+      return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
+
+  std::vector<llvm::Value*> allocas_for_returned_scalars;
+  if (is_scalar_return) {
+    allocas_for_returned_scalars.push_back(return_value_buffer);
+  } else {
+    constexpr int max_tuple_size = 1000;
+    CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
+        << "Multivalue function can not return more than 1000 elements to avoid"
+        << " stack smashing";
+    allocas_for_returned_scalars =
+        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
+    llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
+
+    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
+  }
 
   Call(FindOrDie(emitted_functions_, &callee),
        GetArrayFunctionCallArguments(
           parameter_addrs, &b_, name,
           /*return_value_buffer=*/return_value_buffer,
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
           /*profile_counters_arg=*/GetProfileCountersArgument()));
 
-  return Load(return_value_buffer);
+  std::vector<llvm::Value*> returned_scalars;
+  returned_scalars.reserve(allocas_for_returned_scalars.size());
+  for (llvm::Value* addr : allocas_for_returned_scalars) {
+    returned_scalars.push_back(Load(addr));
+  }
+  return returned_scalars;
 }
 
 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 478e0be8149..2cd624d70d5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -132,8 +132,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emit code to emit the element at `index` for a reduce instruction.
   StatusOr<llvm::Value*> EmitElementalReduce(
       const HloReduceInstruction* reduce,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::ElementGenerator& initial_value_generator,
+      std::vector<llvm_ir::ElementGenerator> input_generators,
+      std::vector<llvm_ir::ElementGenerator> initial_value_generators,
      const llvm_ir::IrArray::Index& index);
 
  protected:
@@ -197,6 +197,14 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Private helper to initialize an IR function for the computation.
   void InitializeIrFunction(const string& function_name);
 
+  // Emits the copying epilogue for the function, which copies the returned
+  // value into the reserved alloca. This is only necessary for thread-local
+  // functions.
+  // Note that since the call graph is flattened, if the same function is
+  // called in both thread-local and non-thread-local contexts, it is
+  // codegen'd twice, so we know at codegen time whether it is thread-local.
+  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);
+
   // Convenience functions to generate a GEP into the profile counter parameter
   // which would correspond to the index for a given HLO instruction or
   // computation.
@@ -267,12 +275,18 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emits a call to a thread local function (e.g. to the computation nested
   // within a reduce or a map). Thread local callees (by definition) only write
   // to and read from thread local allocations.
+  // Supports only functions returning scalars or tuples of scalars.
   //
   // `parameters` holds the *scalar values* that need to be passed to the
   // callee. The return value is the scalar returned by the callee.
-  llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
-                                   absl::Span<llvm::Value* const> parameters,
-                                   absl::string_view name);
+  std::vector<llvm::Value*> EmitThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name);
+
+  // Similar to EmitThreadLocalCall, but assumes a scalar return value.
+  llvm::Value* EmitScalarReturningThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name);
 
   // Emits a call to a "global" function (e.g. to the computation nested within
   // a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 638ff2c6e0b..83be4334269 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -47,14 +47,14 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
       shape_(target_array.GetShape()),
       b_(b) {}
 
-static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
+static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
     const ElementGenerator& target_element_generator,
     const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
   return [=](const llvm_ir::IrArray::Index array_index) {
     TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
                         target_element_generator(array_index));
     CHECK(target_element->getType()->isStructTy())
-        << "This BodyEmitter is for multi-output fusion, but target element "
+        << "This BodyEmitter is for multi-output, but target element "
           "generator does not produce values of struct type.";
     CHECK_EQ(target_element->getType()->getStructNumElements(),
              target_arrays.size());
@@ -70,7 +70,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
                         absl::Span<const IrArray> target_arrays,
                         llvm::IRBuilder<>* b)
-    : body_emitter_(MakeBodyEmitterForMultiOutputFusion(
+    : body_emitter_(MakeBodyEmitterForMultiOutput(
           target_element_generator,
          std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
       shape_(target_arrays[0].GetShape()),
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 3a4814b1857..e00b93e973c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -91,6 +91,32 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
   llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
 }
 
+std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
+    const Shape& tuple_shape, llvm::IRBuilder<>* b) {
+  llvm::Module* module = b->GetInsertBlock()->getModule();
+
+  llvm::IRBuilder<>::InsertPointGuard guard(*b);
+  llvm::Function* function = b->GetInsertBlock()->getParent();
+  b->SetInsertPoint(&function->getEntryBlock(),
+                    function->getEntryBlock().getFirstInsertionPt());
+  CHECK(tuple_shape.IsTuple());
+  int tuple_size = tuple_shape.tuple_shapes_size();
+
+  std::vector<llvm::Value*> generated_allocas;
+  for (int i = 0; i < tuple_size; i++) {
+    const Shape& element_shape = tuple_shape.tuple_shapes(i);
+    CHECK(ShapeUtil::IsScalar(element_shape));
+    llvm::Type* type =
+        llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module);
+    llvm::AllocaInst* alloca = b->CreateAlloca(
+        type,
+        /*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i)));
+    generated_allocas.push_back(alloca);
+  }
+
+  return generated_allocas;
+}
+
 llvm::Value* EmitGetTupleElement(const
 Shape& target_shape, int64 index,
                                 int alignment, llvm::Value* operand,
                                 llvm::IRBuilder<>* b) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index 67d6323aba2..1e173801139 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -68,6 +68,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
 void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
               llvm::IRBuilder<>* b);
 
+// Emits one alloca for each element in the tuple of shape tuple_shape and
+// returns the emitted allocas.
+// Precondition: tuple_shape should be a tuple of scalars.
+std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
+    const Shape& tuple_shape, llvm::IRBuilder<>* b);
+
 // Similar to EmitTuple above, except that the output buffers are provided in
 // the form of IrArray.
 void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index ab875c1c4bc..8c54d6d83ed 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1168,6 +1168,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 6d2c2fc79ce..88cdb7e5928 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -1002,5 +1003,100 @@ XLA_TEST_F(ReduceTest, R0ReduceInDisguise) { ErrorSpec(0.001)); } +class VariadicReduceTest : public HloTestBase {}; + +XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) { + absl::string_view hlo_string = R"( + HloModule Reduce_R3x2_to_R1x2_simple + + add { + op1 = f32[] parameter(0) + op2 = f32[] parameter(1) + acc1 = f32[] parameter(2) + acc2 = f32[] parameter(3) + out1 = f32[] add(acc1, op1) + out2 = f32[] add(acc2, op2) + ROOT result = (f32[], f32[]) tuple(out1, out2) + } + + ENTRY main { + inp1 = f32[10,20,3] parameter(0) + inp2 = f32[10,20,3] parameter(1) + zero = f32[] constant(0) + + ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero), + dimensions={1,2}, + to_apply=add + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) { + absl::string_view hlo_string = R"( + HloModule Reduce_R1x2_to_R0x2_simple + + add { + op1 = f32[] parameter(0) + op2 = f32[] parameter(1) + acc1 = f32[] parameter(2) + acc2 = f32[] parameter(3) + out1 = f32[] add(acc1, op1) + out2 = f32[] add(acc2, op2) + ROOT result = (f32[], f32[]) tuple(out1, out2) + } + + ENTRY main { + inp1 = f32[100] parameter(0) + inp2 = f32[100] parameter(1) + zero = f32[] constant(0) + + ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero), + dimensions={0}, + to_apply=add + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) { + absl::string_view hlo_string = R"( + HloModule Reduce_R1x2_to_R0x2_argmax + + argmax { + running_max = u32[] parameter(0) + running_max_idx = u32[] parameter(1) + current_value = u32[] parameter(2) + current_value_idx = u32[] parameter(3) + + current = (u32[], u32[]) tuple(running_max, running_max_idx) + potential = (u32[], u32[]) tuple(current_value, current_value_idx) + + cmp_code = pred[] compare(current_value, running_max), direction=GT + + new_max = u32[] select(cmp_code, current_value, running_max) + new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx) + + ROOT out = (u32[], u32[]) tuple(new_max, new_idx) + } + + ENTRY main { + input = u32[100] parameter(0) + idxs = u32[100]{0} iota(), iota_dimension=0 + zero = u32[] constant(0) + zero_idx = u32[] constant(0) + + ROOT out = (u32[], u32[]) reduce( + input, idxs, zero, zero_idx), + dimensions={0}, + to_apply=%argmax + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla