[XLA:GPU] Use the generic implementation for elemental reduce
The generic version used in fusions didn't support variadic reduction on GPU (it did on CPU), so tie up some loose ends and use the generic version.

PiperOrigin-RevId: 313428251
Change-Id: Ide547280b0fcf04a99a51b721d8ca860c9da6305
Commit: b266b46825
Parent: dc18758c27
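For context, a variadic reduce is a single reduce op with several input/init operand pairs and one reduction computation that consumes and produces that many values per step (for example, a running maximum together with its index). The sketch below is not XLA code; it is a minimal host-side C++ analogue, with hypothetical names, of the semantics the elemental emitter now has to model: the nested computation returns one value per output, which is why NestedComputer and ComputeNestedElement in the diff change to return StatusOr<std::vector<llvm::Value*>> instead of a single llvm::Value*.

// Host-side sketch only (hypothetical helpers, not part of XLA): a variadic
// reduce over one dimension, with two operands (values, element indices) and
// one reduction computation that returns two results per step.
#include <cstdint>
#include <limits>
#include <tuple>
#include <utility>
#include <vector>

// The "to_apply" computation: takes both accumulators plus both input
// elements and returns the updated accumulators (running max and argmax).
std::pair<float, int64_t> MaxAndArg(float acc_val, int64_t acc_idx,
                                    float elem_val, int64_t elem_idx) {
  return elem_val > acc_val ? std::make_pair(elem_val, elem_idx)
                            : std::make_pair(acc_val, acc_idx);
}

// Variadic reduce of a [rows][cols] input over the column dimension: it
// produces two outputs of length `rows` by threading two accumulators
// through the inner reduction loop, one per output.
std::pair<std::vector<float>, std::vector<int64_t>> ReduceMaxAndArgmax(
    const std::vector<std::vector<float>>& input) {
  std::vector<float> max_out;
  std::vector<int64_t> arg_out;
  for (const auto& row : input) {
    float acc_val = -std::numeric_limits<float>::infinity();  // init value #0
    int64_t acc_idx = 0;                                      // init value #1
    for (int64_t i = 0; i < static_cast<int64_t>(row.size()); ++i) {
      std::tie(acc_val, acc_idx) = MaxAndArg(acc_val, acc_idx, row[i], i);
    }
    max_out.push_back(acc_val);
    arg_out.push_back(acc_idx);
  }
  return {max_out, arg_out};
}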
@@ -40,7 +40,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
  public:
   // A NestedComputer computes an element of the output of the given computation
   // given a Span of its input elements.
-  using NestedComputer = std::function<StatusOr<llvm::Value*>(
+  using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
       const HloComputation&, absl::Span<llvm::Value* const>)>;
 
   GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
@@ -91,12 +91,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
       absl::string_view) override {
-    // TODO(b/118332391): Supported variadic return values.
-    auto result = compute_nested_(callee, parameters);
-    if (!result.ok()) {
-      return result.status();
-    }
-    return std::vector<llvm::Value*>{result.ValueOrDie()};
+    return compute_nested_(callee, parameters);
   }
 
   llvm::Value* EmitThreadId() override;
@@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleReduce(HloInstruction* instr) {
-  const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
-  const Shape& out_shape = reduce->shape();
-  bool returns_tuple = !out_shape.IsArray();
-  int accumulators_count = 1;
-  if (returns_tuple) {
-    CHECK(out_shape.IsTuple());
-    accumulators_count = out_shape.tuple_shapes_size();
-  }
-
-  auto arg = reduce->operand(0);
-  absl::Span<const int64> dimensions(reduce->dimensions());
-  HloComputation* function = reduce->to_apply();
-  return EmitTargetElementLoop(
-      *reduce,
-      [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        std::vector<llvm::Value*> accumulator_addrs;
-        std::vector<llvm::Type*> accumulator_types;
-
-        // Initialize accumulators with initial values.
-        for (int i = 0; i < accumulators_count; i++) {
-          auto init_value = reduce->init_values()[i];
-          const Shape& element_shape =
-              returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
-          PrimitiveType accumulator_type = element_shape.element_type();
-          llvm::Type* accumulator_llvm_type =
-              llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
-          llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
-          Store(Load(GetBasePointer(*init_value)), accumulator_addr);
-          accumulator_addrs.push_back(accumulator_addr);
-          accumulator_types.push_back(accumulator_llvm_type);
-        }
-
-        // The enclosing loops go over all the target elements. Now we have to
-        // compute the actual target element. For this, we build a new loop nest
-        // to iterate over all the reduction dimensions in the argument.
-        // AddLoopsForShapeOnDimensions will return an Index where induction
-        // Value*s are placed for each dimension in dimensions, and all the rest
-        // are nullptrs.
-        llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
-        std::vector<llvm::Value*> input_multi_index =
-            loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
-                                               "reduction_dim");
-
-        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-        // Build a full index for the input argument, using reduced_dims_index
-        // as the base. In reduced_dims_index only the reduction dimensions are
-        // filled in. We fill in the rest of the dimensions with induction
-        // Value*s taken from 'index' which iterates over the target array.
-        // See the high-level description in the XLA documentation for details.
-        llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
-        for (auto& i : input_multi_index) {
-          if (i == nullptr) {
-            i = *it++;
-          }
-        }
-        CHECK(index.end() == it);
-
-        // Apply the reduction function to the loaded value.
-        llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
-                                            b_.getInt64Ty());
-        std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
-                                                     accumulator_addrs.end());
-        for (int i = 0; i < accumulators_count; i++) {
-          llvm::Value* input_address =
-              GetIrArray(*reduce->operand(i), *reduce)
-                  .EmitArrayElementAddress(input_index, &b_);
-          reduction_operands.push_back(input_address);
-        }
-
-        llvm::Value* ret_argument;
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          ret_argument = accumulator_addrs[0];
-        } else {
-          const Shape& return_shape = function->root_instruction()->shape();
-
-          llvm::Type* return_value_buffer_type =
-              llvm_ir::ShapeToIrType(return_shape, module_);
-          ret_argument = Alloca(return_value_buffer_type);
-          llvm_ir::IrArray tuple_array(ret_argument, return_shape);
-          EmitTuple(tuple_array, accumulator_addrs, &b_);
-        }
-
-        TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-            *function, reduction_operands, ret_argument));
-
-        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-
-        if (!returns_tuple) {
-          CHECK_EQ(accumulator_addrs.size(), 1);
-          return Load(accumulator_addrs[0]);
-        } else {
-          // Emit a struct for the LoopEmitter dealing with multi-output
-          // fusion.
-          llvm::Value* returned_structure = llvm::UndefValue::get(
-              llvm::StructType::get(b_.getContext(), accumulator_types));
-          for (int i = 0; i < accumulators_count; i++) {
-            llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
-            returned_structure =
-                b_.CreateInsertValue(returned_structure, accumulator_value, i);
-          }
-          return returned_structure;
-        }
-      });
-}
-
 Status IrEmitter::HandleFusion(HloInstruction* fusion) {
   // kFusion for library calls should be handled by
   // IrEmitterUnnested::HandleFusion.
@@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
       "to a cudnn CustomCall using CudnnBatchNormRewriter.");
 }
 
-StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
+StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
     const HloComputation& computation,
     absl::Span<llvm::Value* const> parameter_elements) {
+  const Shape& return_shape = computation.root_instruction()->shape();
   llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(
-          computation.root_instruction()->shape().element_type(), module_),
-      "return_buffer", &b_);
+      llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_);
   std::vector<llvm::Value*> parameter_buffers;
   for (llvm::Value* parameter_element : parameter_elements) {
     parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
         parameter_element->getType(), "parameter_buffer", &b_));
     Store(parameter_element, parameter_buffers.back());
   }
+
+  std::vector<llvm::Value*> allocas_for_returned_scalars;
+  if (!return_shape.IsTuple()) {
+    allocas_for_returned_scalars.push_back(return_buffer);
+  } else {
+    allocas_for_returned_scalars =
+        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
+    llvm_ir::IrArray tuple_array(return_buffer, return_shape);
+
+    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
+  }
+
   TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
                                                  return_buffer));
-  return Load(return_buffer);
+
+  std::vector<llvm::Value*> returned_scalars;
+  returned_scalars.reserve(allocas_for_returned_scalars.size());
+  for (llvm::Value* addr : allocas_for_returned_scalars) {
+    returned_scalars.push_back(Load(addr));
+  }
+  return returned_scalars;
 }
 
 std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
@@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleRecv(HloInstruction* recv) override;
   Status HandleRecvDone(HloInstruction* recv_done) override;
   Status HandleParameter(HloInstruction* parameter) override;
-  Status HandleReduce(HloInstruction* reduce) override;
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
@@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
       const llvm_ir::IrArray::Index& compare_keys_index,
       const llvm_ir::IrArray& keys_array);
 
-  StatusOr<llvm::Value*> ComputeNestedElement(
+  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
       const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);
 