Variadic reduce implementation on GPU.

Implements a slower, non-vectorized version.
A faster, vectorized version still remains to be done.

PiperOrigin-RevId: 240849148
This commit is contained in:
George Karpenkov 2019-03-28 14:09:22 -07:00 committed by TensorFlower Gardener
parent fb772b781b
commit b651a2cb5a
7 changed files with 165 additions and 58 deletions

View File

@ -38,17 +38,21 @@ using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
absl::Span<const HloInstruction* const> io_hlos,
absl::Span<const HloInstruction* const> non_io_hlos) {
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
// I/O HLOs are bound to the arguments of the current IR function,
// *excluding* the output argument, which is added to non-I/O HLOs.
// I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
// void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg, temp_buffer_base) {
llvm::Function* function = b_->GetInsertBlock()->getParent();
CHECK_EQ(io_hlos.size() + 1, function->arg_size());
CHECK_EQ(io_hlos.size() + 2, function->arg_size());
// An HLO can have duplicated operands. This data structure remembers which
// operand HLOs are already bound to avoid rebinding the same HLO.
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
auto arg_iter = function->arg_begin();
for (const HloInstruction* io_hlo : io_hlos) {
CHECK(!absl::c_count(non_io_hlos, io_hlo))
<< "IO HLOs and non-IO HLOs should be disjoint";
if (!already_bound_for_this_function.contains(io_hlo)) {
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
@ -60,6 +64,10 @@ void HloToIrBindings::EmitBasePointersForHlos(
++arg_iter;
}
// Name and skip the output parameter.
arg_iter->setName("output_arg");
++arg_iter;
temp_buffer_base_ = &*arg_iter;
temp_buffer_base_->setName("temp_buffer");

View File

@ -256,6 +256,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
auto producer = consumer->operand(operand_index);
// Don't fuse variadic reduce.
if (consumer->opcode() == HloOpcode::kReduce && consumer->shape().IsTuple()) {
return false;
}
// The following checks are potentially expensive.
if (FusionWouldBeTooLarge(consumer, producer)) {
return false;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@ -32,7 +33,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@ -157,8 +160,7 @@ Status IrEmitter::EmitCallToNestedComputation(
if (emitted_function == nullptr) {
IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
ir_emitter_context_);
TF_RETURN_IF_ERROR(
nested_computation.root_instruction()->Accept(&ir_emitter_nested));
TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation());
emitted_function = ir_emitter_nested.GetEmittedFunction();
}
@ -661,23 +663,38 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
return Status::OK();
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// TODO(b/118332391): Support variadic reduce.
if (!reduce->shape().IsArray()) {
return Unimplemented("Variadic reduce is not supported on GPU");
Status IrEmitter::HandleReduce(HloInstruction* instr) {
const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
const Shape& out_shape = reduce->shape();
bool returns_tuple = !out_shape.IsArray();
int accumulators_count = 1;
if (returns_tuple) {
CHECK(out_shape.IsTuple());
accumulators_count = out_shape.tuple_shapes_size();
}
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
return EmitTargetElementLoop(
*reduce,
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
Alloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), module_));
Store(Load(GetBasePointer(*init_value)), accumulator_addr);
std::vector<llvm::Value*> accumulator_addrs;
std::vector<llvm::Type*> accumulator_types;
// Initialize accumulators with initial values.
for (int i = 0; i < accumulators_count; i++) {
auto init_value = reduce->init_values()[i];
const Shape& element_shape =
returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
PrimitiveType accumulator_type = element_shape.element_type();
llvm::Type* accumulator_llvm_type =
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
Store(Load(GetBasePointer(*init_value)), accumulator_addr);
accumulator_addrs.push_back(accumulator_addr);
accumulator_types.push_back(accumulator_llvm_type);
}
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
@ -709,13 +726,49 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// Apply the reduction function to the loaded value.
llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
b_.getInt64Ty());
llvm::Value* input_address =
GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
accumulator_addrs.end());
for (int i = 0; i < accumulators_count; i++) {
llvm::Value* input_address =
GetIrArray(*reduce->operand(i), *reduce)
.EmitArrayElementAddress(input_index, &b_);
reduction_operands.push_back(input_address);
}
llvm::Value* ret_argument;
if (!returns_tuple) {
CHECK_EQ(accumulator_addrs.size(), 1);
ret_argument = accumulator_addrs[0];
} else {
const Shape& return_shape = function->root_instruction()->shape();
llvm::Type* return_value_buffer_type =
llvm_ir::ShapeToIrType(return_shape, module_);
ret_argument = Alloca(return_value_buffer_type);
llvm_ir::IrArray tuple_array(ret_argument, return_shape);
EmitTuple(tuple_array, accumulator_addrs, &b_);
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*function, {accumulator_addr, input_address}, accumulator_addr));
*function, reduction_operands, ret_argument));
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
return Load(accumulator_addr);
if (!returns_tuple) {
CHECK_EQ(accumulator_addrs.size(), 1);
return Load(accumulator_addrs[0]);
} else {
// Emit a struct for the LoopEmitter dealing with multi-output
// fusion.
llvm::Value* returned_structure = llvm::UndefValue::get(
llvm::StructType::get(b_.getContext(), accumulator_types));
for (int i = 0; i < accumulators_count; i++) {
llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
returned_structure =
b_.CreateInsertValue(returned_structure, accumulator_value, i);
}
return returned_structure;
}
});
}

View File

@ -38,20 +38,18 @@ namespace gpu {
IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
const HloComputation& nested_computation,
IrEmitterContext* ir_emitter_context)
: IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
std::vector<const HloInstruction*> io_hlos;
emitted_function_ =
EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
}
: IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
nested_computation_(nested_computation) {}
llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
const HloComputation& nested_computation,
std::vector<const HloInstruction*>* io_hlos) {
// Nested function serves the same purpose on GPU as a thread-local function on
// a CPU.
Status IrEmitterNested::CodegenNestedComputation() {
std::vector<const HloInstruction*> io_hlos;
std::vector<llvm::Type*> argument_types;
std::vector<int64> argument_dereferenceable_bytes;
for (const HloInstruction* param :
nested_computation.parameter_instructions()) {
io_hlos->push_back(param);
nested_computation_.parameter_instructions()) {
io_hlos.push_back(param);
const Shape& param_shape = param->shape();
argument_types.push_back(
llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
@ -59,9 +57,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
argument_dereferenceable_bytes.push_back(param_size);
}
const HloInstruction* root = nested_computation_.root_instruction();
{
const HloInstruction* root = nested_computation.root_instruction();
io_hlos->push_back(root);
const Shape& root_shape = root->shape();
argument_types.push_back(
llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
@ -79,8 +77,8 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
llvm::GlobalValue::InternalLinkage, // The linkage type.
ir_emitter_context_->name_uniquer()->GetUniqueName(
llvm_ir::SanitizeFunctionName(
nested_computation.name())), // The name of the function.
ir_emitter_context_->llvm_module()); // The parent LLVM module.
nested_computation_.name())), // The name of the function.
ir_emitter_context_->llvm_module()); // The parent LLVM module.
for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
++arg_no) {
int64 arg_size = argument_dereferenceable_bytes[arg_no];
@ -96,17 +94,62 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
llvm::BasicBlock::Create(function->getContext(), "entry", function);
// Emit a "return void" at entry_bb's end, and sets the insert point before
// that return instruction.
b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb));
llvm::ReturnInst* ret_instr =
llvm::ReturnInst::Create(function->getContext(), entry_bb);
b_.SetInsertPoint(ret_instr);
std::vector<const HloInstruction*> non_io_hlos;
for (const auto* hlo : nested_computation.instructions()) {
non_io_hlos.push_back(root);
for (const auto* hlo : nested_computation_.instructions()) {
if (hlo->opcode() != HloOpcode::kParameter &&
hlo != nested_computation.root_instruction()) {
hlo != nested_computation_.root_instruction()) {
non_io_hlos.push_back(hlo);
}
}
bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
return function;
bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
b_.SetInsertPoint(ret_instr);
// Function epilogue: copy the output value back.
{
// TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
const HloInstruction* root_instruction =
nested_computation_.root_instruction();
llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
const Shape& return_shape = root_instruction->shape();
// Second last argument is the out parameter.
llvm::Argument* out_parameter = std::prev(function->arg_end(), 2);
if (ShapeUtil::IsScalar(return_shape)) {
llvm::Value* ret_value = Load(root_value, "load_ret_value");
Store(ret_value,
BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"),
"store_ret_value");
} else {
CHECK(return_shape.IsTuple());
llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);
for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
const Shape& element_shape = return_shape.tuple_shapes(i);
llvm::Value* destination =
llvm_ir::EmitGetTupleElement(element_shape,
/*index=*/i,
/*alignment=*/1, tuple_ptr, &b_);
llvm::Value* source =
llvm_ir::EmitGetTupleElement(element_shape,
/*index=*/i,
/*alignment=*/1, root_value, &b_);
Store(Load(source), destination);
}
}
}
b_.SetInsertPoint(ret_instr);
emitted_function_ = function;
return Status::OK();
}
Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
@ -118,7 +161,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
const llvm_ir::ElementGenerator& element_generator) {
// For MOF we give the loop emitter an array for every output it should
// generate.
if (hlo.IsMultiOutputFusion()) {
if (hlo.shape().IsTuple()) {
std::vector<llvm_ir::IrArray> target_arrays =
ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(

View File

@ -58,11 +58,11 @@ class IrEmitterNested : public IrEmitter {
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& body_emitter) override;
private:
llvm::Function* EmitBasePointersForNestedComputation(
const HloComputation& nested_computation,
std::vector<const HloInstruction*>* io_hlos);
// Generate the code for the computation passed in the constructor.
Status CodegenNestedComputation();
private:
const HloComputation& nested_computation_;
llvm::Function* emitted_function_;
};

View File

@ -632,8 +632,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// a 1D array. The specialized version requires a initializer thunk that
// initializes the output array to the initial value of the reduce.
if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) {
// TODO(b/118332391): Support variadic reduce.
return Unimplemented("Variadic reduce is not supported on GPU");
// TODO(b/129089333): Support tiled vectorized variadic reduce.
return Unimplemented(
"Vectorized variadic reduce is not supported on GPU");
}
return EmitReductionToVector(fusion);
}
@ -722,11 +723,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
}
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
// TODO(b/118332391): Support multi-output reduce.
if (!reduce->shape().IsArray()) {
return Unimplemented("Multi-output reduce is not supported on GPU");
}
if (IsReductionToVector(*reduce)) {
if (IsReductionToVector(*reduce) && reduce->shape().IsArray()) {
return EmitReductionToVector(reduce);
}
@ -2179,9 +2176,10 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
int unroll_factor = thunk->unroll_factor();
VLOG(3) << bindings_.ToString();
const Shape& element_shape = hlo.IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo.shape(), {0})
: hlo.shape();
bool multi_output = hlo.shape().IsTuple();
const Shape& element_shape =
multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape();
VLOG(3) << "EmitTargetElementLoopInThunk "
<< ShapeUtil::HumanStringWithLayout(hlo.shape())
<< " for unroll_factor " << unroll_factor;
@ -2189,7 +2187,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
element_shape, ir_emitter_context_->device_description(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
if (!multi_output) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
launch_dimensions, &b_, unroll_factor)
.EmitLoop(

View File

@ -1038,7 +1038,7 @@ XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) {
class VariadicReduceTest : public HloTestBase {};
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) {
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) {
absl::string_view hlo_string = R"(
HloModule Reduce_R3x2_to_R1x2_simple
@ -1066,7 +1066,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) {
absl::string_view hlo_string = R"(
HloModule Reduce_R3x2_to_R1x2_simple
@ -1094,7 +1094,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) {
absl::string_view hlo_string = R"(
HloModule Reduce_R1x2_to_R0x2_simple
@ -1122,7 +1122,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) {
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) {
absl::string_view hlo_string = R"(
HloModule Reduce_R1x2_to_R0x2_argmax