Variadic reduce implementation on CPU

Implements variadic reduce on the CPU backend.
Before this change, thread-local functions could only return scalars.
The biggest part of this change is allowing thread-local functions to return
tuples of scalars, which required changes to function generation, to the
allocation of space for the returned values on the caller's side, and to
the generation of the function epilogue.

PiperOrigin-RevId: 239022395
George Karpenkov 2019-03-18 11:12:23 -07:00 committed by TensorFlower Gardener
parent 1a690e6c75
commit 803b0f5151
8 changed files with 340 additions and 68 deletions
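To make the new calling convention concrete, here is a minimal hand-written C++ analogue of what the commit message describes. It is an editorial sketch under assumed names (reducer, ret_tuple), not the emitter's literal output; the real code is emitted as LLVM IR by EmitThreadLocalCall on the caller's side and by EmitThreadLocalFunctionEpilogue in the callee. The caller pre-allocates one slot per returned scalar plus a tuple of pointers to those slots, the callee's epilogue writes each scalar through that tuple, and the caller loads the scalars back after the call.

// Editorial sketch: a thread-local reducer returning a tuple of two scalars.
// The tuple of pointers plays the role of the function's result_arg().
#include <cstdio>

static void reducer(float acc1, float acc2, float op1, float op2, void** out) {
  // Callee epilogue: copy each leaf of the returned tuple into the slot the
  // caller reserved for it.
  *static_cast<float*>(out[0]) = acc1 + op1;
  *static_cast<float*>(out[1]) = acc2 + op2;
}

int main() {
  // Caller side: one stack slot per returned scalar (the allocas produced by
  // EmitTupleAllocasAtFunctionEntry) and the tuple of pointers passed as the
  // return buffer.
  float ret0 = 0.0f, ret1 = 0.0f;
  void* ret_tuple[2] = {&ret0, &ret1};

  reducer(/*acc1=*/1.0f, /*acc2=*/2.0f, /*op1=*/10.0f, /*op2=*/20.0f, ret_tuple);

  // The caller then loads each returned scalar, as EmitThreadLocalCall now does.
  std::printf("%f %f\n", ret0, ret1);  // prints 11.000000 22.000000
  return 0;
}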

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -135,10 +136,19 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
};
case HloOpcode::kReduce:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
auto reduce_instr = Cast<HloReduceInstruction>(hlo);
std::vector<llvm_ir::ElementGenerator> input_generators;
for (const HloInstruction* instr : reduce_instr->inputs()) {
input_generators.push_back(operand_to_generator.at(instr));
}
std::vector<llvm_ir::ElementGenerator> initial_value_generators;
for (const HloInstruction* instr : reduce_instr->init_values()) {
initial_value_generators.push_back(operand_to_generator.at(instr));
}
return ir_emitter_->EmitElementalReduce(
Cast<HloReduceInstruction>(hlo),
operand_to_generator.at(hlo->operand(0)),
operand_to_generator.at(hlo->operand(1)), index);
reduce_instr, std::move(input_generators),
std::move(initial_value_generators), index);
};
default:
return ElementalIrEmitter::MakeElementGenerator(hlo,

View File

@@ -58,8 +58,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -106,6 +108,42 @@ IrEmitter::IrEmitter(
TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}
void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
llvm::Argument* out_parameter = compute_function_->result_arg();
llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
const Shape& return_shape = computation->root_instruction()->shape();
if (ShapeUtil::IsScalar(return_shape)) {
llvm::Value* ret_value =
Load(root_value.GetBasePointer(), "load_ret_value");
Store(ret_value,
BitCast(out_parameter, root_value.GetBasePointer()->getType()));
} else {
CHECK(return_shape.IsTuple());
llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
const Shape& element_shape = return_shape.tuple_shapes(i);
llvm::Value* destination = llvm_ir::EmitGetTupleElement(
element_shape,
/*index=*/i,
/*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
&b_);
llvm::Value* source = llvm_ir::EmitGetTupleElement(
element_shape,
/*index=*/i,
/*alignment=*/MinimumAlignmentForShape(element_shape),
root_value.GetBasePointer(), &b_);
Store(Load(source), destination);
}
}
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
@@ -143,6 +181,16 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
InsertOrDie(&emitted_functions_, computation, ir_function);
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
// Function epilogue: copy the returned value either into the return register,
// or into the locations that the return register points to.
const BufferAllocation* root_allocation =
computation_root_allocation_.allocation();
if (root_allocation && root_allocation->is_thread_local()) {
EmitThreadLocalFunctionEpilogue(computation);
}
// Destructor for compute_function_ emits the "ret void" instruction.
compute_function_.reset();
computation_root_allocation_ = BufferAllocation::Slice();
computation_parameter_allocations_.clear();
@@ -634,7 +682,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
elemental_operands, name);
}
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
@@ -716,7 +765,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
b_.getInt64Ty());
TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
input_generator(input_index));
llvm::Value* result = EmitThreadLocalCall(
llvm::Value* result = EmitScalarReturningThreadLocalCall(
*reduce_window->to_apply(), {Load(accumulator_address), input_value},
"reducer_function");
Store(result, accumulator_address);
@@ -868,7 +917,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
llvm::Value* operand_element = Load(operand_address);
llvm::Value* result = EmitThreadLocalCall(
llvm::Value* result = EmitScalarReturningThreadLocalCall(
*select_and_scatter->select(),
{Load(selected_value_address), operand_element}, "select_function");
@@ -903,9 +952,9 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_multi_index, output_array.GetShape(), source_index.GetType());
llvm::Value* output_value =
output_array.EmitReadArrayElement(selected_index, &b_);
llvm::Value* scatter_value =
EmitThreadLocalCall(*select_and_scatter->scatter(),
{output_value, source_value}, "scatter_function");
llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
*select_and_scatter->scatter(), {output_value, source_value},
"scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1665,6 +1714,11 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
absl::Span<const int64> dimensions, HloComputation* function,
string* failure_reason) {
if (!reduce->shape().IsArray()) {
*failure_reason = "vectorization of variadic reduce not implemented";
return false;
}
if (!ReductionPreservesLayout(*reduce)) {
return false;
}
@@ -1813,21 +1867,39 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
const HloReduceInstruction* reduce,
const llvm_ir::ElementGenerator& input_generator,
const llvm_ir::ElementGenerator& initial_value_generator,
std::vector<llvm_ir::ElementGenerator> input_generators,
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index) {
const HloInstruction* arg = reduce->operand(0);
absl::Span<const int64> dimensions(reduce->dimensions());
const Shape& out_shape = reduce->shape();
bool is_variadic = !out_shape.IsArray();
int accumulators_count = 1;
if (is_variadic) {
CHECK(out_shape.IsTuple());
accumulators_count = out_shape.tuple_shapes_size();
}
absl::Span<const int64> reduced_dimensions(reduce->dimensions());
std::vector<llvm::Value*> accumulator_addrs;
std::vector<llvm::Type*> accumulator_types;
for (int i = 0; i < accumulators_count; i++) {
const Shape& element_shape =
is_variadic ? out_shape.tuple_shapes(i) : out_shape;
PrimitiveType accumulator_type = element_shape.element_type();
llvm::Type* accumulator_llvm_type =
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
accumulator_types.push_back(accumulator_llvm_type);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
MinimumAlignmentForPrimitiveType(accumulator_type));
TF_ASSIGN_OR_RETURN(
llvm::Value* const init_value,
initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
Store(init_value, accumulator_addr);
accumulator_addrs.push_back(accumulator_addr);
}
// The enclosing loops go over all the target elements. Now we have to compute
// the actual target element. For this, we build a new loop nest to iterate
@@ -1835,14 +1907,15 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
// AddLoopsForShapeOnDimensions will return an Index where induction Value*s
// are placed for each dimension in dimensions, and all the rest are nullptrs.
llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
const HloInstruction* arg = reduce->operand(0);
std::vector<llvm::Value*> input_multi_index =
loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
"reduction_dim");
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
// Build a full index for the input argument, using reduced_dims_index as the
// base. In reduced_dims_index only the reduction dimensions are filled in. We
// Build a full index for the input argument, using input_multi_index as the
// base. In input_multi_index only the reduction dimensions are filled in. We
// fill in the rest of the dimensions with induction Value*s taken from
// 'index' which iterates over the target array. See the high-level
// description in the XLA documentation for details.
@@ -1857,23 +1930,44 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
b_.getInt64Ty());
// Apply the reduction function to the loaded value.
TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
input_generator(input_index));
llvm::Value* result = EmitThreadLocalCall(
*reduce->to_apply(), {Load(accumulator_addr), input_element},
"reduce_function");
Store(result, accumulator_addr);
std::vector<llvm::Value*> reduction_operands;
for (llvm::Value* accum : accumulator_addrs) {
llvm::Value* accum_value = Load(accum);
reduction_operands.push_back(accum_value);
}
for (int i = 0; i < accumulators_count; i++) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
input_generators[i](input_index));
reduction_operands.push_back(input_element);
}
std::vector<llvm::Value*> results = EmitThreadLocalCall(
*reduce->to_apply(), reduction_operands, "reduce_function");
CHECK(results.size() == accumulators_count);
for (int i = 0; i < accumulators_count; i++) {
Store(results[i], accumulator_addrs[i]);
}
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
return Load(accumulator_addr);
if (is_variadic) {
// Emit a structure, as that is what the LoopEmitter expects.
llvm::Value* returned_structure = llvm::UndefValue::get(
llvm::StructType::get(b_.getContext(), accumulator_types));
for (int i = 0; i < accumulators_count; i++) {
llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
returned_structure =
b_.CreateInsertValue(returned_structure, accumulator_value, i);
}
return returned_structure;
} else {
CHECK_EQ(accumulator_addrs.size(), 1);
return Load(accumulator_addrs[0]);
}
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// TODO(b/118333695): Support variadic reduce.
if (!reduce->shape().IsArray()) {
return Unimplemented("Variadic reduce is not supported on CPU");
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
absl::Span<const int64> dimensions(reduce->dimensions());
@@ -2848,15 +2942,6 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
if (slice == computation_root_allocation_) {
llvm::Argument* retval = compute_function_->result_arg();
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
retval->addAttrs(attr_builder);
return retval;
}
auto param_it =
computation_parameter_allocations_.find(slice.allocation()->index());
if (param_it != computation_parameter_allocations_.end()) {
@@ -2966,7 +3051,8 @@ Status IrEmitter::EmitTargetElementLoop(
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
if (target_op->IsMultiOutputFusion()) {
if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
target_op->opcode() == HloOpcode::kReduce)) {
// For multiple outputs fusion, we need to emit each operand and the root.
TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
std::vector<llvm_ir::IrArray> output_arrays;
@@ -3048,19 +3134,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
llvm::Value* IrEmitter::EmitThreadLocalCall(
llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) {
std::vector<llvm::Value*> return_value =
EmitThreadLocalCall(callee, parameters, name);
CHECK_EQ(return_value.size(), 1);
return return_value[0];
}
std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) {
CHECK(absl::c_binary_search(thread_local_computations_, &callee));
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
// larger arrays is difficult because we allocate the buffer for this return
// value on the stack.
CHECK(ShapeUtil::IsScalar(return_shape));
PrimitiveType return_type = return_shape.element_type();
bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
bool is_tuple_of_scalars_return =
return_shape.IsTuple() &&
absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
return ShapeUtil::IsScalar(shape);
});
CHECK(is_scalar_return || is_tuple_of_scalars_return);
std::vector<llvm::Value*> parameter_addrs;
for (llvm::Value* parameter : parameters) {
@@ -3071,10 +3165,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
parameter_addrs.push_back(parameter_addr);
}
llvm::Type* return_value_buffer_type =
llvm_ir::ShapeToIrType(return_shape, module_);
std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
int retval_alignment =
is_scalar_return
? MinimumAlignmentForPrimitiveType(return_shape.element_type())
: 0;
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
std::vector<llvm::Value*> allocas_for_returned_scalars;
if (is_scalar_return) {
allocas_for_returned_scalars.push_back(return_value_buffer);
} else {
constexpr int max_tuple_size = 1000;
CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
<< "Multivalue function can not return more than 1000 elements to avoid"
<< " stack smashing";
allocas_for_returned_scalars =
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
}
Call(FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
@@ -3085,7 +3199,12 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
/*profile_counters_arg=*/GetProfileCountersArgument()));
return Load(return_value_buffer);
std::vector<llvm::Value*> returned_scalars;
returned_scalars.reserve(allocas_for_returned_scalars.size());
for (llvm::Value* addr : allocas_for_returned_scalars) {
returned_scalars.push_back(Load(addr));
}
return returned_scalars;
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,

View File

@@ -132,8 +132,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emits code that computes the element at `index` for a reduce instruction.
StatusOr<llvm::Value*> EmitElementalReduce(
const HloReduceInstruction* reduce,
const llvm_ir::ElementGenerator& input_generator,
const llvm_ir::ElementGenerator& initial_value_generator,
std::vector<llvm_ir::ElementGenerator> input_generators,
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index);
protected:
@@ -197,6 +197,14 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
// Emits the copying epilogue for the function,
// where it copies the returned value to the reserved alloca.
// This is only necessary for thread-local functions.
// Note that since the call graph is flattened, if the same function is
// called in both a thread-local and a non-thread-local context, it is
// codegen'd twice, so we know whether it is thread-local at codegen time.
void EmitThreadLocalFunctionEpilogue(HloComputation* computation);
// Convenience functions to generate a GEP into the profile counter parameter
// which would correspond to the index for a given HLO instruction or
// computation.
@@ -267,11 +275,17 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emits a call to a thread local function (e.g. to the computation nested
// within a reduce or a map). Thread local callees (by definition) only write
// to and read from thread local allocations.
// Supports only functions returning scalars or tuples of scalars.
//
// `parameters` holds the *scalar values* that need to be passed to the
// callee. The return value is the scalar returned by the callee.
llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
absl::Span<llvm::Value* const> parameters,
std::vector<llvm::Value*> EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name);
// Similar to EmitThreadLocalCall, but assumes that the callee returns a scalar.
llvm::Value* EmitScalarReturningThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within

View File

@@ -47,14 +47,14 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
shape_(target_array.GetShape()),
b_(b) {}
static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
const ElementGenerator& target_element_generator,
const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
return [=](const llvm_ir::IrArray::Index array_index) {
TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
target_element_generator(array_index));
CHECK(target_element->getType()->isStructTy())
<< "This BodyEmitter is for multi-output fusion, but target element "
<< "This BodyEmitter is for multi-output, but target element "
"generator does not produce values of struct type.";
CHECK_EQ(target_element->getType()->getStructNumElements(),
target_arrays.size());
@@ -70,7 +70,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
absl::Span<const IrArray> target_arrays,
llvm::IRBuilder<>* b)
: body_emitter_(MakeBodyEmitterForMultiOutputFusion(
: body_emitter_(MakeBodyEmitterForMultiOutput(
target_element_generator,
std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
shape_(target_arrays[0].GetShape()),

View File

@@ -91,6 +91,32 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
}
std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
const Shape& tuple_shape, llvm::IRBuilder<>* b) {
llvm::Module* module = b->GetInsertBlock()->getModule();
llvm::IRBuilder<>::InsertPointGuard guard(*b);
llvm::Function* function = b->GetInsertBlock()->getParent();
b->SetInsertPoint(&function->getEntryBlock(),
function->getEntryBlock().getFirstInsertionPt());
CHECK(tuple_shape.IsTuple());
int tuple_size = tuple_shape.tuple_shapes_size();
std::vector<llvm::Value*> generated_allocas;
for (int i = 0; i < tuple_size; i++) {
const Shape& element_shape = tuple_shape.tuple_shapes(i);
CHECK(ShapeUtil::IsScalar(element_shape));
llvm::Type* type =
llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module);
llvm::AllocaInst* alloca = b->CreateAlloca(
type,
/*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i)));
generated_allocas.push_back(alloca);
}
return generated_allocas;
}
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b) {

View File

@@ -68,6 +68,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b);
// Emits one alloca for each element in the tuple of shape tuple_shape,
// returns the emitted allocas.
// Precondition: tuple_shape should be a tuple of scalars.
std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
const Shape& tuple_shape, llvm::IRBuilder<>* b);
// Similar to EmitTuple above, except that the output buffers are provided in
// the form of IrArray.
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,

View File

@@ -1168,6 +1168,7 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",

View File

@@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -1002,5 +1003,100 @@ XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
ErrorSpec(0.001));
}
class VariadicReduceTest : public HloTestBase {};
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
absl::string_view hlo_string = R"(
HloModule Reduce_R3x2_to_R1x2_simple
add {
op1 = f32[] parameter(0)
op2 = f32[] parameter(1)
acc1 = f32[] parameter(2)
acc2 = f32[] parameter(3)
out1 = f32[] add(acc1, op1)
out2 = f32[] add(acc2, op2)
ROOT result = (f32[], f32[]) tuple(out1, out2)
}
ENTRY main {
inp1 = f32[10,20,3] parameter(0)
inp2 = f32[10,20,3] parameter(1)
zero = f32[] constant(0)
ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero),
dimensions={1,2},
to_apply=add
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
absl::string_view hlo_string = R"(
HloModule Reduce_R1x2_to_R0x2_simple
add {
op1 = f32[] parameter(0)
op2 = f32[] parameter(1)
acc1 = f32[] parameter(2)
acc2 = f32[] parameter(3)
out1 = f32[] add(acc1, op1)
out2 = f32[] add(acc2, op2)
ROOT result = (f32[], f32[]) tuple(out1, out2)
}
ENTRY main {
inp1 = f32[100] parameter(0)
inp2 = f32[100] parameter(1)
zero = f32[] constant(0)
ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero),
dimensions={0},
to_apply=add
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) {
absl::string_view hlo_string = R"(
HloModule Reduce_R1x2_to_R0x2_argmax
argmax {
running_max = u32[] parameter(0)
running_max_idx = u32[] parameter(1)
current_value = u32[] parameter(2)
current_value_idx = u32[] parameter(3)
current = (u32[], u32[]) tuple(running_max, running_max_idx)
potential = (u32[], u32[]) tuple(current_value, current_value_idx)
cmp_code = pred[] compare(current_value, running_max), direction=GT
new_max = u32[] select(cmp_code, current_value, running_max)
new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
ROOT out = (u32[], u32[]) tuple(new_max, new_idx)
}
ENTRY main {
input = u32[100] parameter(0)
idxs = u32[100]{0} iota(), iota_dimension=0
zero = u32[] constant(0)
zero_idx = u32[] constant(0)
ROOT out = (u32[], u32[]) reduce(
input, idxs, zero, zero_idx),
dimensions={0},
to_apply=%argmax
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
} // namespace
} // namespace xla
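
Editorial aside, not part of the commit: the argmax reducer in the last test keeps a running (max, index) pair and compares it with each (value, index) pair coming from the two reduced operands. Below is a minimal C++ analogue of that reducer and of the dimension-0 reduction it drives; all names here are hypothetical and the snippet only illustrates the semantics the HLO test checks.

#include <cstdint>
#include <cstdio>
#include <utility>

using Pair = std::pair<std::uint32_t, std::uint32_t>;  // (value, index)

// Analogue of the `argmax` HLO computation: keep the larger value and its
// index; on ties the running pair wins, matching the GT comparison.
static Pair Argmax(Pair running, Pair current) {
  return current.first > running.first ? current : running;
}

int main() {
  const std::uint32_t input[] = {3, 7, 7, 1};
  Pair acc{0, 0};  // both accumulators start at zero, as in the test
  for (std::uint32_t i = 0; i < 4; ++i) {
    acc = Argmax(acc, Pair{input[i], i});
  }
  std::printf("max=%u idx=%u\n", acc.first, acc.second);  // max=7 idx=1
  return 0;
}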