Variadic reduce implementation on CPU
Implements variadic reduce on the CPU backend. Before this change, thread-local functions could only return scalars. The biggest part of this change is allowing thread-local functions to return tuples of scalars, which required changes to function generation, allocation of space for the returned values on the callers side, and generating the function epilogue. PiperOrigin-RevId: 239022395
This commit is contained in:
parent
1a690e6c75
commit
803b0f5151
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
@ -135,10 +136,19 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
|
||||
};
|
||||
case HloOpcode::kReduce:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
auto reduce_instr = Cast<HloReduceInstruction>(hlo);
|
||||
std::vector<llvm_ir::ElementGenerator> input_generators;
|
||||
for (const HloInstruction* instr : reduce_instr->inputs()) {
|
||||
input_generators.push_back(operand_to_generator.at(instr));
|
||||
}
|
||||
|
||||
std::vector<llvm_ir::ElementGenerator> initial_value_generators;
|
||||
for (const HloInstruction* instr : reduce_instr->init_values()) {
|
||||
initial_value_generators.push_back(operand_to_generator.at(instr));
|
||||
}
|
||||
return ir_emitter_->EmitElementalReduce(
|
||||
Cast<HloReduceInstruction>(hlo),
|
||||
operand_to_generator.at(hlo->operand(0)),
|
||||
operand_to_generator.at(hlo->operand(1)), index);
|
||||
reduce_instr, std::move(input_generators),
|
||||
std::move(initial_value_generators), index);
|
||||
};
|
||||
default:
|
||||
return ElementalIrEmitter::MakeElementGenerator(hlo,
|
||||
|
@ -58,8 +58,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -106,6 +108,42 @@ IrEmitter::IrEmitter(
|
||||
TF_CHECK_OK(s) << "Should have failed buffer assignment.";
|
||||
}
|
||||
|
||||
void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
|
||||
llvm::Argument* out_parameter = compute_function_->result_arg();
|
||||
llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
|
||||
const Shape& return_shape = computation->root_instruction()->shape();
|
||||
|
||||
if (ShapeUtil::IsScalar(return_shape)) {
|
||||
llvm::Value* ret_value =
|
||||
Load(root_value.GetBasePointer(), "load_ret_value");
|
||||
Store(ret_value,
|
||||
BitCast(out_parameter, root_value.GetBasePointer()->getType()));
|
||||
} else {
|
||||
CHECK(return_shape.IsTuple());
|
||||
|
||||
llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
|
||||
llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
|
||||
llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
|
||||
|
||||
for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
|
||||
const Shape& element_shape = return_shape.tuple_shapes(i);
|
||||
llvm::Value* destination = llvm_ir::EmitGetTupleElement(
|
||||
element_shape,
|
||||
/*index=*/i,
|
||||
/*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
|
||||
&b_);
|
||||
|
||||
llvm::Value* source = llvm_ir::EmitGetTupleElement(
|
||||
element_shape,
|
||||
/*index=*/i,
|
||||
/*alignment=*/MinimumAlignmentForShape(element_shape),
|
||||
root_value.GetBasePointer(), &b_);
|
||||
|
||||
Store(Load(source), destination);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
|
||||
HloComputation* computation, const string& function_name_prefix,
|
||||
bool is_top_level_computation,
|
||||
@ -143,6 +181,16 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
|
||||
InsertOrDie(&emitted_functions_, computation, ir_function);
|
||||
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
|
||||
// IR insert point.
|
||||
|
||||
// Function epilogue: copying the value over to either the return register,
|
||||
// or values pointing from the return register.
|
||||
const BufferAllocation* root_allocation =
|
||||
computation_root_allocation_.allocation();
|
||||
if (root_allocation && root_allocation->is_thread_local()) {
|
||||
EmitThreadLocalFunctionEpilogue(computation);
|
||||
}
|
||||
|
||||
// Destructor for compute_function_ emits the "ret void" instruction.
|
||||
compute_function_.reset();
|
||||
computation_root_allocation_ = BufferAllocation::Slice();
|
||||
computation_parameter_allocations_.clear();
|
||||
@ -634,7 +682,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||
llvm::Value* IrEmitter::EmitElementalMap(
|
||||
const HloMapInstruction& map_instr,
|
||||
absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
|
||||
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
|
||||
return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
|
||||
elemental_operands, name);
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
|
||||
@ -716,7 +765,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
|
||||
b_.getInt64Ty());
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
|
||||
input_generator(input_index));
|
||||
llvm::Value* result = EmitThreadLocalCall(
|
||||
llvm::Value* result = EmitScalarReturningThreadLocalCall(
|
||||
*reduce_window->to_apply(), {Load(accumulator_address), input_value},
|
||||
"reducer_function");
|
||||
Store(result, accumulator_address);
|
||||
@ -868,7 +917,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
||||
llvm::Value* operand_address =
|
||||
operand_array.EmitArrayElementAddress(operand_index, &b_);
|
||||
llvm::Value* operand_element = Load(operand_address);
|
||||
llvm::Value* result = EmitThreadLocalCall(
|
||||
llvm::Value* result = EmitScalarReturningThreadLocalCall(
|
||||
*select_and_scatter->select(),
|
||||
{Load(selected_value_address), operand_element}, "select_function");
|
||||
|
||||
@ -903,9 +952,9 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
||||
selected_multi_index, output_array.GetShape(), source_index.GetType());
|
||||
llvm::Value* output_value =
|
||||
output_array.EmitReadArrayElement(selected_index, &b_);
|
||||
llvm::Value* scatter_value =
|
||||
EmitThreadLocalCall(*select_and_scatter->scatter(),
|
||||
{output_value, source_value}, "scatter_function");
|
||||
llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
|
||||
*select_and_scatter->scatter(), {output_value, source_value},
|
||||
"scatter_function");
|
||||
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
|
||||
|
||||
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
|
||||
@ -1665,6 +1714,11 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
|
||||
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
|
||||
absl::Span<const int64> dimensions, HloComputation* function,
|
||||
string* failure_reason) {
|
||||
if (!reduce->shape().IsArray()) {
|
||||
*failure_reason = "vectorization of variadic reduce not implemented";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ReductionPreservesLayout(*reduce)) {
|
||||
return false;
|
||||
}
|
||||
@ -1813,21 +1867,39 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
|
||||
|
||||
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
|
||||
const HloReduceInstruction* reduce,
|
||||
const llvm_ir::ElementGenerator& input_generator,
|
||||
const llvm_ir::ElementGenerator& initial_value_generator,
|
||||
std::vector<llvm_ir::ElementGenerator> input_generators,
|
||||
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
|
||||
const llvm_ir::IrArray::Index& index) {
|
||||
const HloInstruction* arg = reduce->operand(0);
|
||||
absl::Span<const int64> dimensions(reduce->dimensions());
|
||||
const Shape& out_shape = reduce->shape();
|
||||
bool is_variadic = !out_shape.IsArray();
|
||||
int accumulators_count = 1;
|
||||
if (is_variadic) {
|
||||
CHECK(out_shape.IsTuple());
|
||||
accumulators_count = out_shape.tuple_shapes_size();
|
||||
}
|
||||
|
||||
absl::Span<const int64> reduced_dimensions(reduce->dimensions());
|
||||
|
||||
std::vector<llvm::Value*> accumulator_addrs;
|
||||
std::vector<llvm::Type*> accumulator_types;
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
const Shape& element_shape =
|
||||
is_variadic ? out_shape.tuple_shapes(i) : out_shape;
|
||||
PrimitiveType accumulator_type = element_shape.element_type();
|
||||
llvm::Type* accumulator_llvm_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
|
||||
accumulator_types.push_back(accumulator_llvm_type);
|
||||
|
||||
// Initialize an accumulator with init_value.
|
||||
PrimitiveType accumulator_type = reduce->shape().element_type();
|
||||
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
|
||||
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
|
||||
accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
|
||||
MinimumAlignmentForPrimitiveType(accumulator_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
llvm::Value* const init_value,
|
||||
initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
|
||||
initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
|
||||
Store(init_value, accumulator_addr);
|
||||
accumulator_addrs.push_back(accumulator_addr);
|
||||
}
|
||||
|
||||
// The enclosing loops go over all the target elements. Now we have to compute
|
||||
// the actual target element. For this, we build a new loop nest to iterate
|
||||
@ -1835,14 +1907,15 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
|
||||
// AddLoopsForShapeOnDimensions will return an Index where induction Value*s
|
||||
// are placed for each dimension in dimensions, and all the rest are nullptrs.
|
||||
llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
|
||||
const HloInstruction* arg = reduce->operand(0);
|
||||
std::vector<llvm::Value*> input_multi_index =
|
||||
loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
|
||||
loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
|
||||
"reduction_dim");
|
||||
|
||||
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
|
||||
|
||||
// Build a full index for the input argument, using reduced_dims_index as the
|
||||
// base. In reduced_dims_index only the reduction dimensions are filled in. We
|
||||
// Build a full index for the input argument, using input_multi_index as the
|
||||
// base. In input_multi_index only the reduction dimensions are filled in. We
|
||||
// fill in the rest of the dimensions with induction Value*s taken from
|
||||
// 'index' which iterates over the target array. See the high-level
|
||||
// description in the XLA documentation for details.
|
||||
@ -1857,23 +1930,44 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
|
||||
llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
|
||||
b_.getInt64Ty());
|
||||
|
||||
// Apply the reduction function to the loaded value.
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
|
||||
input_generator(input_index));
|
||||
llvm::Value* result = EmitThreadLocalCall(
|
||||
*reduce->to_apply(), {Load(accumulator_addr), input_element},
|
||||
"reduce_function");
|
||||
Store(result, accumulator_addr);
|
||||
std::vector<llvm::Value*> reduction_operands;
|
||||
for (llvm::Value* accum : accumulator_addrs) {
|
||||
llvm::Value* accum_value = Load(accum);
|
||||
reduction_operands.push_back(accum_value);
|
||||
}
|
||||
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
|
||||
input_generators[i](input_index));
|
||||
reduction_operands.push_back(input_element);
|
||||
}
|
||||
|
||||
std::vector<llvm::Value*> results = EmitThreadLocalCall(
|
||||
*reduce->to_apply(), reduction_operands, "reduce_function");
|
||||
|
||||
CHECK(results.size() == accumulators_count);
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
Store(results[i], accumulator_addrs[i]);
|
||||
}
|
||||
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
|
||||
return Load(accumulator_addr);
|
||||
|
||||
if (is_variadic) {
|
||||
// Emit a structure, as that what the LoopEmitter expects.
|
||||
llvm::Value* returned_structure = llvm::UndefValue::get(
|
||||
llvm::StructType::get(b_.getContext(), accumulator_types));
|
||||
for (int i = 0; i < accumulators_count; i++) {
|
||||
llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
|
||||
returned_structure =
|
||||
b_.CreateInsertValue(returned_structure, accumulator_value, i);
|
||||
}
|
||||
return returned_structure;
|
||||
} else {
|
||||
CHECK_EQ(accumulator_addrs.size(), 1);
|
||||
return Load(accumulator_addrs[0]);
|
||||
}
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
|
||||
// TODO(b/118333695): Support variadic reduce.
|
||||
if (!reduce->shape().IsArray()) {
|
||||
return Unimplemented("Variadic reduce is not supported on CPU");
|
||||
}
|
||||
auto arg = reduce->mutable_operand(0);
|
||||
auto init_value = reduce->mutable_operand(1);
|
||||
absl::Span<const int64> dimensions(reduce->dimensions());
|
||||
@ -2848,15 +2942,6 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
|
||||
const BufferAllocation::Slice& slice, const Shape& target_shape) {
|
||||
const BufferAllocation& allocation = *slice.allocation();
|
||||
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
|
||||
if (slice == computation_root_allocation_) {
|
||||
llvm::Argument* retval = compute_function_->result_arg();
|
||||
llvm::AttrBuilder attr_builder;
|
||||
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
|
||||
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
|
||||
retval->addAttrs(attr_builder);
|
||||
return retval;
|
||||
}
|
||||
|
||||
auto param_it =
|
||||
computation_parameter_allocations_.find(slice.allocation()->index());
|
||||
if (param_it != computation_parameter_allocations_.end()) {
|
||||
@ -2966,7 +3051,8 @@ Status IrEmitter::EmitTargetElementLoop(
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
|
||||
llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
|
||||
|
||||
if (target_op->IsMultiOutputFusion()) {
|
||||
if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
|
||||
target_op->opcode() == HloOpcode::kReduce)) {
|
||||
// For multiple outputs fusion, we need to emit each operand and the root.
|
||||
TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
|
||||
std::vector<llvm_ir::IrArray> output_arrays;
|
||||
@ -3048,19 +3134,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
|
||||
}
|
||||
|
||||
llvm::Value* IrEmitter::EmitThreadLocalCall(
|
||||
llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view name) {
|
||||
std::vector<llvm::Value*> return_value =
|
||||
EmitThreadLocalCall(callee, parameters, name);
|
||||
CHECK_EQ(return_value.size(), 1);
|
||||
return return_value[0];
|
||||
}
|
||||
|
||||
std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view name) {
|
||||
CHECK(absl::c_binary_search(thread_local_computations_, &callee));
|
||||
|
||||
const Shape& return_shape = callee.root_instruction()->shape();
|
||||
|
||||
// Lifting this restriction to allow "small" arrays should be easy. Allowing
|
||||
// larger arrays is difficult because we allocate the buffer for this return
|
||||
// value on the stack.
|
||||
CHECK(ShapeUtil::IsScalar(return_shape));
|
||||
|
||||
PrimitiveType return_type = return_shape.element_type();
|
||||
bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
|
||||
bool is_tuple_of_scalars_return =
|
||||
return_shape.IsTuple() &&
|
||||
absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
|
||||
return ShapeUtil::IsScalar(shape);
|
||||
});
|
||||
CHECK(is_scalar_return || is_tuple_of_scalars_return);
|
||||
|
||||
std::vector<llvm::Value*> parameter_addrs;
|
||||
for (llvm::Value* parameter : parameters) {
|
||||
@ -3071,10 +3165,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
|
||||
parameter_addrs.push_back(parameter_addr);
|
||||
}
|
||||
|
||||
llvm::Type* return_value_buffer_type =
|
||||
llvm_ir::ShapeToIrType(return_shape, module_);
|
||||
std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
|
||||
int retval_alignment =
|
||||
is_scalar_return
|
||||
? MinimumAlignmentForPrimitiveType(return_shape.element_type())
|
||||
: 0;
|
||||
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
|
||||
absl::StrCat(name, "_retval_addr"), &b_,
|
||||
MinimumAlignmentForPrimitiveType(return_type));
|
||||
return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
|
||||
|
||||
std::vector<llvm::Value*> allocas_for_returned_scalars;
|
||||
if (is_scalar_return) {
|
||||
allocas_for_returned_scalars.push_back(return_value_buffer);
|
||||
} else {
|
||||
constexpr int max_tuple_size = 1000;
|
||||
CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
|
||||
<< "Multivalue function can not return more than 1000 elements to avoid"
|
||||
<< " stack smashing";
|
||||
allocas_for_returned_scalars =
|
||||
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
|
||||
llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
|
||||
|
||||
EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
|
||||
}
|
||||
|
||||
Call(FindOrDie(emitted_functions_, &callee),
|
||||
GetArrayFunctionCallArguments(
|
||||
@ -3085,7 +3199,12 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
|
||||
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
|
||||
/*profile_counters_arg=*/GetProfileCountersArgument()));
|
||||
|
||||
return Load(return_value_buffer);
|
||||
std::vector<llvm::Value*> returned_scalars;
|
||||
returned_scalars.reserve(allocas_for_returned_scalars.size());
|
||||
for (llvm::Value* addr : allocas_for_returned_scalars) {
|
||||
returned_scalars.push_back(Load(addr));
|
||||
}
|
||||
return returned_scalars;
|
||||
}
|
||||
|
||||
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
|
||||
|
@ -132,8 +132,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
// Emit code to emit the element at `index` for a reduce instruction.
|
||||
StatusOr<llvm::Value*> EmitElementalReduce(
|
||||
const HloReduceInstruction* reduce,
|
||||
const llvm_ir::ElementGenerator& input_generator,
|
||||
const llvm_ir::ElementGenerator& initial_value_generator,
|
||||
std::vector<llvm_ir::ElementGenerator> input_generators,
|
||||
std::vector<llvm_ir::ElementGenerator> initial_value_generator,
|
||||
const llvm_ir::IrArray::Index& index);
|
||||
|
||||
protected:
|
||||
@ -197,6 +197,14 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
// Private helper to initialize an IR function for the computation.
|
||||
void InitializeIrFunction(const string& function_name);
|
||||
|
||||
// Emits the copying epilogue for the function,
|
||||
// where it copies the returned value to the reserved alloca.
|
||||
// This is only necessary for thread-local functions.
|
||||
// Note that since the call graph is flattened, if the same function is
|
||||
// called in both thread-local and non-thread-local it would be codegen'd
|
||||
// twice, and we would know whether it's thread-local at codegen time.
|
||||
void EmitThreadLocalFunctionEpilogue(HloComputation* computation);
|
||||
|
||||
// Convenience functions to generate a GEP into the profile counter parameter
|
||||
// which would correspond to the index for a given HLO instruction or
|
||||
// computation.
|
||||
@ -267,11 +275,17 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
// Emits a call to a thread local function (e.g. to the computation nested
|
||||
// within a reduce or a map). Thread local callees (by definition) only write
|
||||
// to and read from thread local allocations.
|
||||
// Supports only functions returning scalars or tuples of scalars.
|
||||
//
|
||||
// `parameters` holds the *scalar values* that need to be passed to the
|
||||
// callee. The return value is the scalar returned by the callee.
|
||||
llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
|
||||
absl::Span<llvm::Value* const> parameters,
|
||||
std::vector<llvm::Value*> EmitThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view name);
|
||||
|
||||
// Similar to EmitThreadLocal, yet assumes that the function returns a scalar.
|
||||
llvm::Value* EmitScalarReturningThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view name);
|
||||
|
||||
// Emits a call to a "global" function (e.g. to the computation nested within
|
||||
|
@ -47,14 +47,14 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
|
||||
shape_(target_array.GetShape()),
|
||||
b_(b) {}
|
||||
|
||||
static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
|
||||
static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
|
||||
const ElementGenerator& target_element_generator,
|
||||
const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
|
||||
return [=](const llvm_ir::IrArray::Index array_index) {
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
|
||||
target_element_generator(array_index));
|
||||
CHECK(target_element->getType()->isStructTy())
|
||||
<< "This BodyEmitter is for multi-output fusion, but target element "
|
||||
<< "This BodyEmitter is for multi-output, but target element "
|
||||
"generator does not produce values of struct type.";
|
||||
CHECK_EQ(target_element->getType()->getStructNumElements(),
|
||||
target_arrays.size());
|
||||
@ -70,7 +70,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
|
||||
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
|
||||
absl::Span<const IrArray> target_arrays,
|
||||
llvm::IRBuilder<>* b)
|
||||
: body_emitter_(MakeBodyEmitterForMultiOutputFusion(
|
||||
: body_emitter_(MakeBodyEmitterForMultiOutput(
|
||||
target_element_generator,
|
||||
std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
|
||||
shape_(target_arrays[0].GetShape()),
|
||||
|
@ -91,6 +91,32 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
|
||||
llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
|
||||
}
|
||||
|
||||
std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
|
||||
const Shape& tuple_shape, llvm::IRBuilder<>* b) {
|
||||
llvm::Module* module = b->GetInsertBlock()->getModule();
|
||||
|
||||
llvm::IRBuilder<>::InsertPointGuard guard(*b);
|
||||
llvm::Function* function = b->GetInsertBlock()->getParent();
|
||||
b->SetInsertPoint(&function->getEntryBlock(),
|
||||
function->getEntryBlock().getFirstInsertionPt());
|
||||
CHECK(tuple_shape.IsTuple());
|
||||
int tuple_size = tuple_shape.tuple_shapes_size();
|
||||
|
||||
std::vector<llvm::Value*> generated_allocas;
|
||||
for (int i = 0; i < tuple_size; i++) {
|
||||
const Shape& element_shape = tuple_shape.tuple_shapes(i);
|
||||
CHECK(ShapeUtil::IsScalar(element_shape));
|
||||
llvm::Type* type =
|
||||
llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module);
|
||||
llvm::AllocaInst* alloca = b->CreateAlloca(
|
||||
type,
|
||||
/*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i)));
|
||||
generated_allocas.push_back(alloca);
|
||||
}
|
||||
|
||||
return generated_allocas;
|
||||
}
|
||||
|
||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
int alignment, llvm::Value* operand,
|
||||
llvm::IRBuilder<>* b) {
|
||||
|
@ -68,6 +68,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
|
||||
llvm::IRBuilder<>* b);
|
||||
|
||||
// Emits one alloca for each element in the tuple of shape tuple_shape,
|
||||
// returns the emitted allocas.
|
||||
// Precondition: tuple_shape should be a tuple of scalars.
|
||||
std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
|
||||
const Shape& tuple_shape, llvm::IRBuilder<>* b);
|
||||
|
||||
// Similar to EmitTuple above, except that the output buffers are provided in
|
||||
// the form of IrArray.
|
||||
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
|
||||
|
@ -1168,6 +1168,7 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -1002,5 +1003,100 @@ XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
|
||||
ErrorSpec(0.001));
|
||||
}
|
||||
|
||||
class VariadicReduceTest : public HloTestBase {};
|
||||
|
||||
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Reduce_R3x2_to_R1x2_simple
|
||||
|
||||
add {
|
||||
op1 = f32[] parameter(0)
|
||||
op2 = f32[] parameter(1)
|
||||
acc1 = f32[] parameter(2)
|
||||
acc2 = f32[] parameter(3)
|
||||
out1 = f32[] add(acc1, op1)
|
||||
out2 = f32[] add(acc2, op2)
|
||||
ROOT result = (f32[], f32[]) tuple(out1, out2)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
inp1 = f32[10,20,3] parameter(0)
|
||||
inp2 = f32[10,20,3] parameter(1)
|
||||
zero = f32[] constant(0)
|
||||
|
||||
ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero),
|
||||
dimensions={1,2},
|
||||
to_apply=add
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
|
||||
}
|
||||
|
||||
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Reduce_R1x2_to_R0x2_simple
|
||||
|
||||
add {
|
||||
op1 = f32[] parameter(0)
|
||||
op2 = f32[] parameter(1)
|
||||
acc1 = f32[] parameter(2)
|
||||
acc2 = f32[] parameter(3)
|
||||
out1 = f32[] add(acc1, op1)
|
||||
out2 = f32[] add(acc2, op2)
|
||||
ROOT result = (f32[], f32[]) tuple(out1, out2)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
inp1 = f32[100] parameter(0)
|
||||
inp2 = f32[100] parameter(1)
|
||||
zero = f32[] constant(0)
|
||||
|
||||
ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero),
|
||||
dimensions={0},
|
||||
to_apply=add
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
|
||||
}
|
||||
|
||||
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Reduce_R1x2_to_R0x2_argmax
|
||||
|
||||
argmax {
|
||||
running_max = u32[] parameter(0)
|
||||
running_max_idx = u32[] parameter(1)
|
||||
current_value = u32[] parameter(2)
|
||||
current_value_idx = u32[] parameter(3)
|
||||
|
||||
current = (u32[], u32[]) tuple(running_max, running_max_idx)
|
||||
potential = (u32[], u32[]) tuple(current_value, current_value_idx)
|
||||
|
||||
cmp_code = pred[] compare(current_value, running_max), direction=GT
|
||||
|
||||
new_max = u32[] select(cmp_code, current_value, running_max)
|
||||
new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
|
||||
|
||||
ROOT out = (u32[], u32[]) tuple(new_max, new_idx)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
input = u32[100] parameter(0)
|
||||
idxs = u32[100]{0} iota(), iota_dimension=0
|
||||
zero = u32[] constant(0)
|
||||
zero_idx = u32[] constant(0)
|
||||
|
||||
ROOT out = (u32[], u32[]) reduce(
|
||||
input, idxs, zero, zero_idx),
|
||||
dimensions={0},
|
||||
to_apply=%argmax
|
||||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user