Variadic reduce implementation on GPU.
Implements the slower, non-vectorized version; the faster, vectorized version remains to be done.

PiperOrigin-RevId: 240849148
commit b651a2cb5a
parent fb772b781b
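For context: a variadic reduce takes N input arrays and N init values, keeps N accumulators, and applies one reducer to (accumulators..., current elements...) once per reduced element, producing a tuple of N results. The standalone C++ sketch below illustrates that semantics for the argmax-style case that the re-enabled Reduce_R1x2_to_R0x2_argmax test exercises; it is illustrative only, not code from this change, and all names in it are invented for the example.

#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

// Reducer with 2N scalar parameters: the running accumulators first, then the
// current elements; it returns the N updated accumulators.
static std::pair<float, int> ArgmaxReducer(float acc_value, int acc_index,
                                           float value, int index) {
  return value > acc_value ? std::make_pair(value, index)
                           : std::make_pair(acc_value, acc_index);
}

int main() {
  const std::vector<float> values = {1.0f, 7.5f, 3.0f, 7.0f};
  const std::vector<int> indices = {0, 1, 2, 3};

  // One accumulator per output, seeded with the init values.
  float acc_value = -1e30f;  // init value for the f32 output
  int acc_index = 0;         // init value for the s32 output

  // The non-vectorized GPU path emits the same shape of computation: a scalar
  // loop over the reduced dimension that calls the reducer once per step.
  for (int i = 0; i < static_cast<int>(values.size()); ++i) {
    std::tie(acc_value, acc_index) =
        ArgmaxReducer(acc_value, acc_index, values[i], indices[i]);
  }

  std::printf("max=%f argmax=%d\n", acc_value, acc_index);  // max=7.5 argmax=1
  return 0;
}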
@@ -38,17 +38,21 @@ using absl::StrCat;
 void HloToIrBindings::EmitBasePointersForHlos(
     absl::Span<const HloInstruction* const> io_hlos,
     absl::Span<const HloInstruction* const> non_io_hlos) {
-  // I/O HLOs are bound to the arguments of the current IR function. I.e.,
+  // I/O HLOs are bound to the arguments of the current IR function,
+  // *excluding* the output argument, which is added to non-I/O HLOs.
+  // I.e.,
   //
-  // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
+  // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg, temp_buffer_base) {
   llvm::Function* function = b_->GetInsertBlock()->getParent();
-  CHECK_EQ(io_hlos.size() + 1, function->arg_size());
+  CHECK_EQ(io_hlos.size() + 2, function->arg_size());

   // An HLO can have duplicated operands. This data structure remembers which
   // operand HLOs are already bound to avoid rebinding the same HLO.
   absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
   auto arg_iter = function->arg_begin();
   for (const HloInstruction* io_hlo : io_hlos) {
+    CHECK(!absl::c_count(non_io_hlos, io_hlo))
+        << "IO HLOs and non-IO HLOs should be disjoint";
     if (!already_bound_for_this_function.contains(io_hlo)) {
       if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
         BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
@@ -60,6 +64,10 @@ void HloToIrBindings::EmitBasePointersForHlos(
     ++arg_iter;
   }

+  // Name and skip the output parameter.
+  arg_iter->setName("output_arg");
+  ++arg_iter;
+
   temp_buffer_base_ = &*arg_iter;
   temp_buffer_base_->setName("temp_buffer");

@@ -256,6 +256,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     return false;
   }
   auto producer = consumer->operand(operand_index);
+
+  // Don't fuse variadic reduce.
+  if (consumer->opcode() == HloOpcode::kReduce && consumer->shape().IsTuple()) {
+    return false;
+  }
   // The following checks are potentially expensive.
   if (FusionWouldBeTooLarge(consumer, producer)) {
     return false;
@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -32,7 +33,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@@ -157,8 +160,7 @@ Status IrEmitter::EmitCallToNestedComputation(
   if (emitted_function == nullptr) {
     IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
                                       ir_emitter_context_);
-    TF_RETURN_IF_ERROR(
-        nested_computation.root_instruction()->Accept(&ir_emitter_nested));
+    TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation());
     emitted_function = ir_emitter_nested.GetEmittedFunction();
   }

@@ -661,23 +663,38 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
   return Status::OK();
 }

-Status IrEmitter::HandleReduce(HloInstruction* reduce) {
-  // TODO(b/118332391): Support variadic reduce.
-  if (!reduce->shape().IsArray()) {
-    return Unimplemented("Variadic reduce is not supported on GPU");
+Status IrEmitter::HandleReduce(HloInstruction* instr) {
+  const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(instr);
+  const Shape& out_shape = reduce->shape();
+  bool returns_tuple = !out_shape.IsArray();
+  int accumulators_count = 1;
+  if (returns_tuple) {
+    CHECK(out_shape.IsTuple());
+    accumulators_count = out_shape.tuple_shapes_size();
   }
+
   auto arg = reduce->operand(0);
-  auto init_value = reduce->operand(1);
   absl::Span<const int64> dimensions(reduce->dimensions());
   HloComputation* function = reduce->to_apply();
   return EmitTargetElementLoop(
       *reduce,
       [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        // Initialize an accumulator with init_value.
-        llvm::AllocaInst* accumulator_addr =
-            Alloca(llvm_ir::PrimitiveTypeToIrType(
-                reduce->shape().element_type(), module_));
-        Store(Load(GetBasePointer(*init_value)), accumulator_addr);
+        std::vector<llvm::Value*> accumulator_addrs;
+        std::vector<llvm::Type*> accumulator_types;
+
+        // Initialize accumulators with initial values.
+        for (int i = 0; i < accumulators_count; i++) {
+          auto init_value = reduce->init_values()[i];
+          const Shape& element_shape =
+              returns_tuple ? out_shape.tuple_shapes(i) : out_shape;
+          PrimitiveType accumulator_type = element_shape.element_type();
+          llvm::Type* accumulator_llvm_type =
+              llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
+          llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type);
+          Store(Load(GetBasePointer(*init_value)), accumulator_addr);
+          accumulator_addrs.push_back(accumulator_addr);
+          accumulator_types.push_back(accumulator_llvm_type);
+        }

         // The enclosing loops go over all the target elements. Now we have to
         // compute the actual target element. For this, we build a new loop nest
@@ -709,13 +726,49 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
         // Apply the reduction function to the loaded value.
         llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                             b_.getInt64Ty());
-        llvm::Value* input_address =
-            GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
+        std::vector<llvm::Value*> reduction_operands(accumulator_addrs.begin(),
+                                                     accumulator_addrs.end());
+        for (int i = 0; i < accumulators_count; i++) {
+          llvm::Value* input_address =
+              GetIrArray(*reduce->operand(i), *reduce)
+                  .EmitArrayElementAddress(input_index, &b_);
+          reduction_operands.push_back(input_address);
+        }
+
+        llvm::Value* ret_argument;
+        if (!returns_tuple) {
+          CHECK_EQ(accumulator_addrs.size(), 1);
+          ret_argument = accumulator_addrs[0];
+        } else {
+          const Shape& return_shape = function->root_instruction()->shape();
+
+          llvm::Type* return_value_buffer_type =
+              llvm_ir::ShapeToIrType(return_shape, module_);
+          ret_argument = Alloca(return_value_buffer_type);
+          llvm_ir::IrArray tuple_array(ret_argument, return_shape);
+          EmitTuple(tuple_array, accumulator_addrs, &b_);
+        }
+
         TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
-            *function, {accumulator_addr, input_address}, accumulator_addr));
+            *function, reduction_operands, ret_argument));

         SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-        return Load(accumulator_addr);
+
+        if (!returns_tuple) {
+          CHECK_EQ(accumulator_addrs.size(), 1);
+          return Load(accumulator_addrs[0]);
+        } else {
+          // Emit a struct for the LoopEmitter dealing with multi-output
+          // fusion.
+          llvm::Value* returned_structure = llvm::UndefValue::get(
+              llvm::StructType::get(b_.getContext(), accumulator_types));
+          for (int i = 0; i < accumulators_count; i++) {
+            llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+            returned_structure =
+                b_.CreateInsertValue(returned_structure, accumulator_value, i);
+          }
+          return returned_structure;
+        }
       });
 }

@@ -38,20 +38,18 @@ namespace gpu {
 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
                                  const HloComputation& nested_computation,
                                  IrEmitterContext* ir_emitter_context)
-    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
-  std::vector<const HloInstruction*> io_hlos;
-  emitted_function_ =
-      EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
-}
+    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true),
+      nested_computation_(nested_computation) {}

-llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
-    const HloComputation& nested_computation,
-    std::vector<const HloInstruction*>* io_hlos) {
+// Nested function serves the same purpose on GPU as a thread-local function on
+// a CPU.
+Status IrEmitterNested::CodegenNestedComputation() {
+  std::vector<const HloInstruction*> io_hlos;
   std::vector<llvm::Type*> argument_types;
   std::vector<int64> argument_dereferenceable_bytes;
   for (const HloInstruction* param :
-       nested_computation.parameter_instructions()) {
-    io_hlos->push_back(param);
+       nested_computation_.parameter_instructions()) {
+    io_hlos.push_back(param);
     const Shape& param_shape = param->shape();
     argument_types.push_back(
         llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
@@ -59,9 +57,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
         llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
     argument_dereferenceable_bytes.push_back(param_size);
   }
+
+  const HloInstruction* root = nested_computation_.root_instruction();
   {
-    const HloInstruction* root = nested_computation.root_instruction();
-    io_hlos->push_back(root);
     const Shape& root_shape = root->shape();
     argument_types.push_back(
         llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
@@ -79,8 +77,8 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
       llvm::GlobalValue::InternalLinkage,  // The linkage type.
       ir_emitter_context_->name_uniquer()->GetUniqueName(
           llvm_ir::SanitizeFunctionName(
-              nested_computation.name())),  // The name of the function.
+              nested_computation_.name())),  // The name of the function.
       ir_emitter_context_->llvm_module());  // The parent LLVM module.
   for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
        ++arg_no) {
     int64 arg_size = argument_dereferenceable_bytes[arg_no];
@@ -96,17 +94,62 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
       llvm::BasicBlock::Create(function->getContext(), "entry", function);
   // Emit a "return void" at entry_bb's end, and sets the insert point before
   // that return instruction.
-  b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb));
+  llvm::ReturnInst* ret_instr =
+      llvm::ReturnInst::Create(function->getContext(), entry_bb);
+  b_.SetInsertPoint(ret_instr);

   std::vector<const HloInstruction*> non_io_hlos;
-  for (const auto* hlo : nested_computation.instructions()) {
+  non_io_hlos.push_back(root);
+  for (const auto* hlo : nested_computation_.instructions()) {
     if (hlo->opcode() != HloOpcode::kParameter &&
-        hlo != nested_computation.root_instruction()) {
+        hlo != nested_computation_.root_instruction()) {
       non_io_hlos.push_back(hlo);
     }
   }
-  bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
-  return function;
+  bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
+
+  TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
+  b_.SetInsertPoint(ret_instr);
+
+  // Function epilogue: copy the output value back.
+  {
+    // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue
+    const HloInstruction* root_instruction =
+        nested_computation_.root_instruction();
+    llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction);
+    const Shape& return_shape = root_instruction->shape();
+
+    // Second last argument is the out parameter.
+    llvm::Argument* out_parameter = std::prev(function->arg_end(), 2);
+
+    if (ShapeUtil::IsScalar(return_shape)) {
+      llvm::Value* ret_value = Load(root_value, "load_ret_value");
+      Store(ret_value,
+            BitCast(out_parameter, root_value->getType(), "bitcast_ret_value"),
+            "store_ret_value");
+    } else {
+      CHECK(return_shape.IsTuple());
+      llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
+      llvm::Type* tuple_type_ptr = tuple_type->getPointerTo();
+      llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr);
+
+      for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
+        const Shape& element_shape = return_shape.tuple_shapes(i);
+        llvm::Value* destination =
+            llvm_ir::EmitGetTupleElement(element_shape,
+                                         /*index=*/i,
+                                         /*alignment=*/1, tuple_ptr, &b_);
+        llvm::Value* source =
+            llvm_ir::EmitGetTupleElement(element_shape,
+                                         /*index=*/i,
+                                         /*alignment=*/1, root_value, &b_);
+        Store(Load(source), destination);
+      }
+    }
+  }
+  b_.SetInsertPoint(ret_instr);
+  emitted_function_ = function;
+  return Status::OK();
 }

 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
@@ -118,7 +161,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
     const llvm_ir::ElementGenerator& element_generator) {
   // For MOF we give the loop emitter an array for every output it should
   // generate.
-  if (hlo.IsMultiOutputFusion()) {
+  if (hlo.shape().IsTuple()) {
     std::vector<llvm_ir::IrArray> target_arrays =
         ConstructIrArrayForOutputs(hlo);
     TF_RETURN_IF_ERROR(
@@ -58,11 +58,11 @@ class IrEmitterNested : public IrEmitter {
       const HloInstruction& hlo,
       const llvm_ir::ElementGenerator& body_emitter) override;

- private:
-  llvm::Function* EmitBasePointersForNestedComputation(
-      const HloComputation& nested_computation,
-      std::vector<const HloInstruction*>* io_hlos);
+  // Generate the code for the computation passed in the constructor.
+  Status CodegenNestedComputation();

+ private:
+  const HloComputation& nested_computation_;
   llvm::Function* emitted_function_;
 };

@@ -632,8 +632,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     // a 1D array. The specialized version requires a initializer thunk that
     // initializes the output array to the initial value of the reduce.
     if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) {
-      // TODO(b/118332391): Support variadic reduce.
-      return Unimplemented("Variadic reduce is not supported on GPU");
+      // TODO(b/129089333): Support tiled vectorized variadic reduce.
+      return Unimplemented(
+          "Vectorized variadic reduce is not supported on GPU");
     }
     return EmitReductionToVector(fusion);
   }
@@ -722,11 +723,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
 }

 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
-  // TODO(b/118332391): Support multi-output reduce.
-  if (!reduce->shape().IsArray()) {
-    return Unimplemented("Multi-output reduce is not supported on GPU");
-  }
-  if (IsReductionToVector(*reduce)) {
+  if (IsReductionToVector(*reduce) && reduce->shape().IsArray()) {
     return EmitReductionToVector(reduce);
   }

@@ -2179,9 +2176,10 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
   int unroll_factor = thunk->unroll_factor();
   VLOG(3) << bindings_.ToString();

-  const Shape& element_shape = hlo.IsMultiOutputFusion()
-                                   ? ShapeUtil::GetSubshape(hlo.shape(), {0})
-                                   : hlo.shape();
+  bool multi_output = hlo.shape().IsTuple();
+
+  const Shape& element_shape =
+      multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape();
   VLOG(3) << "EmitTargetElementLoopInThunk "
           << ShapeUtil::HumanStringWithLayout(hlo.shape())
           << " for unroll_factor " << unroll_factor;
@@ -2189,7 +2187,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
       element_shape, ir_emitter_context_->device_description(), unroll_factor);
   UpdateLaunchDimensions(launch_dimensions, thunk,
                          ir_emitter_context_->llvm_module());
-  if (!hlo.IsMultiOutputFusion()) {
+  if (!multi_output) {
     return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
                                launch_dimensions, &b_, unroll_factor)
         .EmitLoop(
@@ -1038,7 +1038,7 @@ XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) {

 class VariadicReduceTest : public HloTestBase {};

-XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) {
+XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) {
   absl::string_view hlo_string = R"(
 HloModule Reduce_R3x2_to_R1x2_simple

@@ -1066,7 +1066,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R2x2_simple)) {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
 }

-XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
+XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) {
   absl::string_view hlo_string = R"(
 HloModule Reduce_R3x2_to_R1x2_simple

@@ -1094,7 +1094,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
 }

-XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
+XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) {
   absl::string_view hlo_string = R"(
 HloModule Reduce_R1x2_to_R0x2_simple

@@ -1122,7 +1122,7 @@ XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
 }

-XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) {
+XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) {
   absl::string_view hlo_string = R"(
 HloModule Reduce_R1x2_to_R0x2_argmax

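For orientation only: the sketch below shows roughly how a client could request the kind of variadic reduce these tests run. It assumes the multi-operand Reduce overload of the XLA client builder (an xla::Reduce free function taking spans of operands and init values); the helper name BuildVariadicArgmaxReduce and the argmax_reducer computation are hypothetical, and the snippet is not taken from the diff above.

#include <cstdint>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds a tuple-returning (f32[], s32[]) argmax reduce over f32[8]/s32[8]
// inputs. `argmax_reducer` is assumed to be a computation taking
// (f32[], s32[], f32[], s32[]) and returning (f32[], s32[]).
xla::XlaOp BuildVariadicArgmaxReduce(xla::XlaBuilder* b,
                                     const xla::XlaComputation& argmax_reducer) {
  xla::XlaOp values = xla::Parameter(
      b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "values");
  xla::XlaOp indices = xla::Parameter(
      b, 1, xla::ShapeUtil::MakeShape(xla::S32, {8}), "indices");
  xla::XlaOp init_value = xla::ConstantR0<float>(b, -1e30f);
  xla::XlaOp init_index = xla::ConstantR0<int32_t>(b, 0);

  // Two operands and two init values reduced together along dimension 0; the
  // result is a tuple (max value, index of the max).
  return xla::Reduce(b, {values, indices}, {init_value, init_index},
                     argmax_reducer, /*dimensions_to_reduce=*/{0});
}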