[XLA:GPU] Use the generic implementation for elemental reduce

The generic version used in fusions didn't support variadic reduction
on GPU (it did on CPU), so tie up those loose ends and switch GPU over to the
generic implementation.

PiperOrigin-RevId: 313428251
Change-Id: Ide547280b0fcf04a99a51b721d8ca860c9da6305
Author: Benjamin Kramer
Date:   2020-05-27 11:26:31 -07:00
Committed by: TensorFlower Gardener
Parent: dc18758c27
Commit: b266b46825

3 changed files with 25 additions and 123 deletions

tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h

@@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
 public:
   // A NestedComputer computes an element of the output of the given computation
   // given a Span of its input elements.
-  using NestedComputer = std::function<StatusOr<llvm::Value*>(
+  using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
       const HloComputation&, absl::Span<llvm::Value* const>)>;

   GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
@@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
       absl::string_view) override {
-    // TODO(b/118332391): Supported variadic return values.
-    auto result = compute_nested_(callee, parameters);
-    if (!result.ok()) {
-      return result.status();
-    }
-    return std::vector<llvm::Value*>{result.ValueOrDie()};
+    return compute_nested_(callee, parameters);
   }

   llvm::Value* EmitThreadId() override;
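
The core of this header change: NestedComputer, and with it EmitThreadLocalCall, now traffics in std::vector<llvm::Value*> instead of a single llvm::Value*, so a variadic nested computation can return one IR value per tuple element while a single-result computation simply returns a one-element vector. A self-contained sketch of that interface generalization, using plain C++ stand-ins rather than the XLA/LLVM types:

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

using Value = int;  // stand-in for llvm::Value*
using NestedComputer =
    std::function<std::vector<Value>(const std::vector<Value>&)>;

int main() {
  // A single-output computation returns a one-element vector...
  NestedComputer add = [](const std::vector<Value>& args) {
    return std::vector<Value>{args[0] + args[1]};
  };
  // ...and a variadic computation returns one element per output,
  // through the exact same interface, with no special-casing at the
  // call site.
  NestedComputer min_max = [](const std::vector<Value>& args) {
    return std::vector<Value>{std::min(args[0], args[1]),
                              std::max(args[0], args[1])};
  };
  std::cout << add({2, 3})[0] << " " << min_max({2, 3})[1] << "\n";  // 5 3
}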

tensorflow/compiler/xla/service/gpu/ir_emitter.cc

@@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   return Status::OK();
 }

-Status IrEmitter::HandleReduce(HloInstruction* instr) {
-  const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
-  const Shape& out_shape = reduce->shape();
-  bool returns_tuple = !out_shape.IsArray();
-  int accumulators_count = 1;
-  if (returns_tuple) {
-    CHECK(out_shape.IsTuple());
-    accumulators_count = out_shape.tuple_shapes_size();
-  }
-
-  auto arg = reduce->operand(0);
-  absl::Span<const int64> dimensions(reduce->dimensions());
-  HloComputation* function = reduce->to_apply();
-  return EmitTargetElementLoop(
-      *reduce,
-      [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        std::vector<llvm::Value*> accumulator_addrs;
-        std::vector<llvm::Type*> accumulator_types;
-
-        // Initialize accumulators with initial values.
-        for (int i = 0; i < accumulators_count; i++) {
-          auto init_value = reduce->init_values()[i];
-          const Shape& element_shape =
-              returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
-          PrimitiveType accumulator_type = element_shape.element_type();
-          llvm::Type* accumulator_llvm_type =
-              llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
-          llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
-          Store(Load(GetBasePointer(*init_value)), accumulator_addr);
-          accumulator_addrs.push_back(accumulator_addr);
-          accumulator_types.push_back(accumulator_llvm_type);
-        }
-
-        // The enclosing loops go over all the target elements. Now we have to
-        // compute the actual target element. For this, we build a new loop nest
-        // to iterate over all the reduction dimensions in the argument.
-        // AddLoopsForShapeOnDimensions will return an Index where induction
-        // Value*s are placed for each dimension in dimensions, and all the rest
-        // are nullptrs.
-        llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
-        std::vector<llvm::Value*> input_multi_index =
-            loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
-                                               "reduction_dim");
-        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-        // Build a full index for the input argument, using reduced_dims_index
-        // as the base. In reduced_dims_index only the reduction dimensions are
-        // filled in. We fill in the rest of the dimensions with induction
-        // Value*s taken from 'index' which iterates over the target array.
-        // See the high-level description in the XLA documentation for details.
-        llvm_ir::IrArray::Index::const_iterator it = index.begin();
-        for (auto& i : input_multi_index) {
-          if (i == nullptr) {
-            i = *it++;
-          }
-        }
-        CHECK(index.end() == it);
-
-        // Apply the reduction function to the loaded value.
-        llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
-                                            b_.getInt64Ty());
-        std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
-                                                     accumulator_addrs.end());
-        for (int i = 0; i < accumulators_count; i++) {
-          llvm::Value* input_address =
-              GetIrArray(*reduce->operand(i), *reduce)
-                  .EmitArrayElementAddress(input_index, &b_);
-          reduction_operands.push_back(input_address);
-        }
-
-        llvm::Value* ret_argument;
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          ret_argument = accumulator_addrs[0];
-        } else {
-          const Shape& return_shape = function->root_instruction()->shape();
-          llvm::Type* return_value_buffer_type =
-              llvm_ir::ShapeToIrType(return_shape, module_);
-          ret_argument = Alloca(return_value_buffer_type);
-          llvm_ir::IrArray tuple_array(ret_argument, return_shape);
-          EmitTuple(tuple_array, accumulator_addrs, &b_);
-        }
-
-        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-            *function, reduction_operands, ret_argument));
-
-        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          return Load(accumulator_addrs[0]);
-        } else {
-          // Emit a struct for the LoopEmitter dealing with multi-output
-          // fusion.
-          llvm::Value* returned_structure = llvm::UndefValue::get(
-              llvm::StructType::get(b_.getContext(), accumulator_types));
-          for (int i = 0; i < accumulators_count; i++) {
-            llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
-            returned_structure =
-                b_.CreateInsertValue(returned_structure, accumulator_value, i);
-          }
-          return returned_structure;
-        }
-      });
-}
-
 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
   // kFusion for library calls should be handled by
   // IrEmitterUnnested::HandleFusion.
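
Everything deleted above is the hand-rolled GPU lowering of reduce: per output element it allocates one accumulator per operand, runs an inner loop nest over the reduced dimensions, and calls the user's reduction computation on (accumulators..., current elements...). The generic ElementalIrEmitter now emits the same structure. As ordinary C++, the control flow being replaced looks roughly like this (hypothetical helper, two f32 operands reduced along the inner dimension):

#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// to_apply(acc0, acc1, x0, x1) -> (new_acc0, new_acc1), mirroring the
// accumulators-then-elements argument order the deleted code built into
// reduction_operands.
std::pair<std::vector<float>, std::vector<float>> ReducePair(
    const Matrix& a, const Matrix& b, float init0, float init1,
    std::pair<float, float> (*to_apply)(float, float, float, float)) {
  std::vector<float> out0, out1;
  for (std::size_t i = 0; i < a.size(); ++i) {     // enclosing "target" loop
    float acc0 = init0, acc1 = init1;              // one accumulator per operand
    for (std::size_t j = 0; j < a[i].size(); ++j)  // inner reduction loop nest
      std::tie(acc0, acc1) = to_apply(acc0, acc1, a[i][j], b[i][j]);
    out0.push_back(acc0);
    out1.push_back(acc1);
  }
  return {out0, out1};
}

static std::pair<float, float> SumBoth(float acc0, float acc1, float x0,
                                       float x1) {
  return {acc0 + x0, acc1 + x1};
}

int main() {
  Matrix a = {{1, 2, 3}, {4, 5, 6}};
  Matrix b = {{10, 20, 30}, {40, 50, 60}};
  auto [r0, r1] = ReducePair(a, b, 0.f, 0.f, &SumBoth);
  std::printf("%g %g\n", r0[0], r1[1]);  // 6 150
}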
@@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
       "to a cudnn CustomCall using CudnnBatchNormRewriter.");
 }

-StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
+StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
     const HloComputation& computation,
     absl::Span<llvm::Value* const> parameter_elements) {
+  const Shape& return_shape = computation.root_instruction()->shape();
   llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(
-          computation.root_instruction()->shape().element_type(), module_),
-      "return_buffer", &b_);
+      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
   std::vector<llvm::Value*> parameter_buffers;
   for (llvm::Value* parameter_element : parameter_elements) {
     parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
         parameter_element->getType(), "parameter_buffer", &b_));
     Store(parameter_element, parameter_buffers.back());
   }
+
+  std::vector<llvm::Value*> allocas_for_returned_scalars;
+  if (!return_shape.IsTuple()) {
+    allocas_for_returned_scalars.push_back(return_buffer);
+  } else {
+    allocas_for_returned_scalars =
+        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
+    llvm_ir::IrArray tuple_array(return_buffer, return_shape);
+    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
+  }
+
   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                  return_buffer));
-  return Load(return_buffer);
+
+  std::vector<llvm::Value*> returned_scalars;
+  returned_scalars.reserve(allocas_for_returned_scalars.size());
+  for (llvm::Value* addr : allocas_for_returned_scalars) {
+    returned_scalars.push_back(Load(addr));
+  }
+  return returned_scalars;
 }

 std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
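
Design note on the rewrite above: for an array-shaped result the single return_buffer alloca is reused directly, so the scalar path costs the same as before. For a tuple-shaped result, one alloca is emitted per leaf element, plus a tuple of pointers (EmitTuple) for the callee to write through, and the scalars are then loaded back out individually. Returning the loaded scalars as a vector is what lets EmitThreadLocalCall in the GPU elemental emitter forward the result of compute_nested_ unchanged.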

tensorflow/compiler/xla/service/gpu/ir_emitter.h

@@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleRecv(HloInstruction* recv) override;
   Status HandleRecvDone(HloInstruction* recv_done) override;
   Status HandleParameter(HloInstruction* parameter) override;
-  Status HandleReduce(HloInstruction* reduce) override;
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
@@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
       const llvm_ir::IrArray::Index& compare_keys_index,
       const llvm_ir::IrArray& keys_array);

-  StatusOr<llvm::Value*> ComputeNestedElement(
+  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
       const HloComputation& computation,
       absl::Span<llvm::Value* const> parameter_elements);