[XLA:CPU] Support multi-operand all-reduce on CPU
PiperOrigin-RevId: 305169720 Change-Id: I98c64d171fdc02692103260fcf2bc78bb2be7a7b
This commit is contained in:
parent
ef5bd97017
commit
4cd7132e91
@ -263,7 +263,6 @@ class CpuAllReduceRendezvous
|
|||||||
protected:
|
protected:
|
||||||
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
|
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
|
||||||
const xla::AllReduceParticipantData& participant) override {
|
const xla::AllReduceParticipantData& participant) override {
|
||||||
TF_RET_CHECK(participant.buffers.size() == 1);
|
|
||||||
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
|
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
|
||||||
bool primary = [&] {
|
bool primary = [&] {
|
||||||
tensorflow::mutex_lock lock(mu_);
|
tensorflow::mutex_lock lock(mu_);
|
||||||
@ -321,60 +320,87 @@ class CpuAllReduceRendezvous
|
|||||||
for (const auto& p : participants_) {
|
for (const auto& p : participants_) {
|
||||||
CHECK(p.reduction_kind == reduction_kind);
|
CHECK(p.reduction_kind == reduction_kind);
|
||||||
}
|
}
|
||||||
|
int num_participants = participants_.size();
|
||||||
|
|
||||||
std::vector<absl::Span<T>> input_buffers;
|
// participant_idx -> buffer_idx -> buffer.
|
||||||
std::vector<absl::Span<T>> output_buffers;
|
std::vector<std::vector<absl::Span<T>>> input_buffers;
|
||||||
input_buffers.reserve(participants_.size());
|
std::vector<std::vector<absl::Span<T>>> output_buffers;
|
||||||
output_buffers.reserve(participants_.size());
|
input_buffers.reserve(num_participants);
|
||||||
|
output_buffers.reserve(num_participants);
|
||||||
|
const xla::AllReduceParticipantData& first_participant =
|
||||||
|
participants_.front();
|
||||||
|
|
||||||
for (auto& p : participants_) {
|
int buffers_per_participant = first_participant.buffers.size();
|
||||||
CHECK_EQ(p.buffers.size(), 1);
|
for (xla::AllReduceParticipantData& p : participants_) {
|
||||||
CHECK_EQ(p.buffers.front().element_count,
|
CHECK_EQ(p.buffers.size(), buffers_per_participant);
|
||||||
participants_.front().buffers.front().element_count);
|
|
||||||
xla::int64 element_count = participant.buffers.front().element_count;
|
input_buffers.emplace_back();
|
||||||
input_buffers.emplace_back(
|
output_buffers.emplace_back();
|
||||||
static_cast<T*>(p.buffers.front().source_data.opaque()),
|
std::vector<absl::Span<T>>& participant_input_buffers =
|
||||||
element_count);
|
input_buffers.back();
|
||||||
output_buffers.emplace_back(
|
std::vector<absl::Span<T>>& participant_output_buffers =
|
||||||
static_cast<T*>(p.buffers.front().destination_data.opaque()),
|
output_buffers.back();
|
||||||
element_count);
|
participant_input_buffers.reserve(p.buffers.size());
|
||||||
|
participant_output_buffers.reserve(p.buffers.size());
|
||||||
|
|
||||||
|
for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
|
||||||
|
buffer_idx++) {
|
||||||
|
auto& participant_buffer = p.buffers[buffer_idx];
|
||||||
|
participant_input_buffers.emplace_back(
|
||||||
|
static_cast<T*>(participant_buffer.source_data.opaque()),
|
||||||
|
participant_buffer.element_count);
|
||||||
|
participant_output_buffers.emplace_back(
|
||||||
|
static_cast<T*>(participant_buffer.destination_data.opaque()),
|
||||||
|
participant_buffer.element_count);
|
||||||
|
CHECK_EQ(participant_buffer.element_count,
|
||||||
|
first_participant.buffers[buffer_idx].element_count);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
xla::int64 element_count =
|
|
||||||
participants_.front().buffers.front().element_count;
|
|
||||||
|
|
||||||
auto compute = [reduction_kind](T a, T b) -> T {
|
for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
|
||||||
switch (reduction_kind) {
|
buffer_idx++) {
|
||||||
case xla::ReductionKind::SUM:
|
int element_count = first_participant.buffers[buffer_idx].element_count;
|
||||||
return a + b;
|
for (int idx = 0; idx < element_count; idx++) {
|
||||||
case xla::ReductionKind::PRODUCT:
|
T out = GetInitialValue<T>(reduction_kind);
|
||||||
return a * b;
|
for (int participant_idx = 0; participant_idx < participants_.size();
|
||||||
case xla::ReductionKind::MIN:
|
participant_idx++) {
|
||||||
return std::min(a, b);
|
out = PerformReductionStep<T>(
|
||||||
case xla::ReductionKind::MAX:
|
reduction_kind, out,
|
||||||
return std::max(a, b);
|
input_buffers[participant_idx][buffer_idx][idx]);
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
for (int idx = 0; idx < element_count; idx++) {
|
|
||||||
T out = [&]() -> T {
|
|
||||||
switch (reduction_kind) {
|
|
||||||
case xla::ReductionKind::SUM:
|
|
||||||
return static_cast<T>(0);
|
|
||||||
case xla::ReductionKind::PRODUCT:
|
|
||||||
return static_cast<T>(1);
|
|
||||||
case xla::ReductionKind::MIN:
|
|
||||||
return std::numeric_limits<T>::max();
|
|
||||||
case xla::ReductionKind::MAX:
|
|
||||||
return std::numeric_limits<T>::min();
|
|
||||||
}
|
}
|
||||||
}();
|
for (int participant_idx = 0; participant_idx < participants_.size();
|
||||||
|
participant_idx++) {
|
||||||
|
output_buffers[participant_idx][buffer_idx][idx] = out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& input : input_buffers) {
|
// Returns the value used to seed a reduction of kind `reduction_kind` over
// elements of type T, i.e. the identity element of that reduction:
//   SUM     -> 0
//   PRODUCT -> 1
//   MIN     -> +infinity if T has one, otherwise the largest finite T
//   MAX     -> -infinity if T has one, otherwise the lowest finite T
// std::numeric_limits<T>::min() must not be used as the MAX seed: for
// floating-point T it is the smallest positive normal value, so reducing
// all-negative inputs would yield a positive result.
// NOTE(review): there is no return after the switch, so a `reduction_kind`
// outside the four enumerators falls off the end of the function.
template <typename T>
|
||||||
out = compute(out, input[idx]);
|
T GetInitialValue(xla::ReductionKind reduction_kind) {
|
||||||
}
|
switch (reduction_kind) {
|
||||||
for (auto& output : output_buffers) {
|
case xla::ReductionKind::SUM:
|
||||||
output[idx] = out;
|
return static_cast<T>(0);
|
||||||
}
|
case xla::ReductionKind::PRODUCT:
|
||||||
|
return static_cast<T>(1);
|
||||||
|
case xla::ReductionKind::MIN:
|
||||||
|
return std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
|
||||||
|
case xla::ReductionKind::MAX:
|
||||||
|
return std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combines two values according to `reduction_kind` and returns the result:
//   SUM -> a + b, PRODUCT -> a * b, MIN -> std::min(a, b),
//   MAX -> std::max(a, b).
// NOTE(review): there is no return after the switch, so a `reduction_kind`
// outside the four enumerators falls off the end of the function.
template <typename T>
|
||||||
|
T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
|
||||||
|
switch (reduction_kind) {
|
||||||
|
case xla::ReductionKind::SUM:
|
||||||
|
return a + b;
|
||||||
|
case xla::ReductionKind::PRODUCT:
|
||||||
|
return a * b;
|
||||||
|
case xla::ReductionKind::MIN:
|
||||||
|
return std::min(a, b);
|
||||||
|
case xla::ReductionKind::MAX:
|
||||||
|
return std::max(a, b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -392,8 +418,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
|||||||
const xla::ExecutableRunOptions* run_options,
|
const xla::ExecutableRunOptions* run_options,
|
||||||
const void* replica_groups_str, xla::int32 replica_groups_str_size,
|
const void* replica_groups_str, xla::int32 replica_groups_str_size,
|
||||||
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
|
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
|
||||||
const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
|
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
|
||||||
void* output_buffer) {
|
void** input_buffers, void** output_buffers) {
|
||||||
absl::string_view replica_groups_serialized(
|
absl::string_view replica_groups_serialized(
|
||||||
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
|
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
|
||||||
|
|
||||||
@ -435,21 +461,25 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
|||||||
|
|
||||||
xla::Shape shape =
|
xla::Shape shape =
|
||||||
DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
|
DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
|
||||||
CHECK(xla::LayoutUtil::IsDenseArray(shape))
|
|
||||||
<< "All-reduce on CPU is implemented only for dense arrays";
|
CHECK((num_buffers > 1 && shape.IsTuple()) ||
|
||||||
|
(num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));
|
||||||
|
|
||||||
xla::AllReduceParticipantData participant(rendezvous_key);
|
xla::AllReduceParticipantData participant(rendezvous_key);
|
||||||
participant.device_ordinal = device_ordinal;
|
participant.device_ordinal = device_ordinal;
|
||||||
participant.stream = run_options->stream();
|
participant.stream = run_options->stream();
|
||||||
xla::AllReduceParticipantData::Buffer buffer;
|
|
||||||
buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
|
|
||||||
buffer.primitive_type = shape.element_type();
|
|
||||||
buffer.source_data =
|
|
||||||
se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
|
|
||||||
buffer.destination_data =
|
|
||||||
se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
|
|
||||||
participant.buffers = {buffer};
|
|
||||||
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
|
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
|
||||||
|
for (int i = 0; i < num_buffers; i++) {
|
||||||
|
xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
|
||||||
|
xla::AllReduceParticipantData::Buffer buffer;
|
||||||
|
buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
|
||||||
|
buffer.primitive_type = subshape.element_type();
|
||||||
|
buffer.source_data = se::DeviceMemoryBase(
|
||||||
|
input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
|
||||||
|
buffer.destination_data = se::DeviceMemoryBase(
|
||||||
|
output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
|
||||||
|
participant.buffers.push_back(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
|
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
|
||||||
return absl::make_unique<CpuAllReduceRendezvous>(k);
|
return absl::make_unique<CpuAllReduceRendezvous>(k);
|
||||||
|
|||||||
@ -167,8 +167,8 @@ extern void __xla_cpu_runtime_AllReduce(
|
|||||||
const xla::ExecutableRunOptions* run_options,
|
const xla::ExecutableRunOptions* run_options,
|
||||||
const void* replica_groups_str, xla::int32 replica_groups_str_size,
|
const void* replica_groups_str, xla::int32 replica_groups_str_size,
|
||||||
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
|
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
|
||||||
const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
|
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
|
||||||
void* output_buffer);
|
void** input_buffers, void** output_buffers);
|
||||||
|
|
||||||
// Write the replica ID into the output buffer.
|
// Write the replica ID into the output buffer.
|
||||||
extern void __xla_cpu_runtime_ReplicaId(
|
extern void __xla_cpu_runtime_ReplicaId(
|
||||||
|
|||||||
@ -1458,6 +1458,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
|
|||||||
/*reduction_kind=*/int32_type,
|
/*reduction_kind=*/int32_type,
|
||||||
/*shape_ptr=*/i8_ptr_type,
|
/*shape_ptr=*/i8_ptr_type,
|
||||||
/*shape_length=*/int32_type,
|
/*shape_length=*/int32_type,
|
||||||
|
/*num_buffers=*/int32_type,
|
||||||
/*input_buffer=*/i8_ptr_type,
|
/*input_buffers=*/i8_ptr_type,
|
||||||
/*output_buffer=*/i8_ptr_type},
|
/*output_buffers=*/i8_ptr_type},
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
@ -1473,19 +1474,41 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
|
|||||||
int32 replica_groups_size = replica_groups.size();
|
int32 replica_groups_size = replica_groups.size();
|
||||||
llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
|
llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
|
||||||
|
|
||||||
Shape shape = crs->operand(0)->shape();
|
bool is_tuple = crs->operand_count() > 1;
|
||||||
|
std::vector<llvm::Value*> input_buffer_ptrs;
|
||||||
|
std::vector<llvm::Value*> output_buffer_ptrs;
|
||||||
|
if (is_tuple) {
|
||||||
|
CHECK(crs->shape().IsTuple());
|
||||||
|
|
||||||
|
for (int64 i = 0; i < crs->operand_count(); i++) {
|
||||||
|
const HloInstruction* op = crs->operand(i);
|
||||||
|
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
|
||||||
|
assignment_.GetUniqueSlice(crs, {i}));
|
||||||
|
const Shape& operand_shape = crs->operand(i)->shape();
|
||||||
|
CHECK(operand_shape.IsArray())
|
||||||
|
<< "Operands to all-reduce must be arrays: " << crs->ToString();
|
||||||
|
output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
|
||||||
|
input_buffer_ptrs.push_back(GetEmittedValueFor(op));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Shape shape = crs->operand(0)->shape();
|
||||||
|
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
|
||||||
|
assignment_.GetUniqueSlice(crs->operand(0), {}));
|
||||||
|
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
|
||||||
|
assignment_.GetUniqueSlice(crs, {}));
|
||||||
|
input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
|
||||||
|
output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Value* input_buffers =
|
||||||
|
EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
|
||||||
|
llvm::Value* output_buffers =
|
||||||
|
EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
|
||||||
|
|
||||||
int32 shape_length;
|
int32 shape_length;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
|
||||||
llvm::Value * shape_ptr,
|
llvm_ir::EncodeSelfDescribingShapeConstant(
|
||||||
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
|
crs->shape(), &shape_length, &b_));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
|
|
||||||
assignment_.GetUniqueSlice(crs->operand(0), {}));
|
|
||||||
llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
|
|
||||||
assignment_.GetUniqueSlice(crs, {}));
|
|
||||||
llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
|
|
||||||
|
|
||||||
Call(all_reduce_func,
|
Call(all_reduce_func,
|
||||||
{/*run_options=*/GetExecutableRunOptionsArgument(),
|
{/*run_options=*/GetExecutableRunOptionsArgument(),
|
||||||
@ -1498,16 +1521,14 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
|
|||||||
b_.getInt64(crs->channel_id().has_value()
|
b_.getInt64(crs->channel_id().has_value()
|
||||||
? *crs->channel_id()
|
? *crs->channel_id()
|
||||||
: crs->GetModule()->unique_id()),
|
: crs->GetModule()->unique_id()),
|
||||||
|
|
||||||
/*reduction_kind=*/
|
/*reduction_kind=*/
|
||||||
b_.getInt32(
|
b_.getInt32(
|
||||||
static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
|
static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
|
||||||
|
|
||||||
/*shape_ptr=*/shape_ptr,
|
/*shape_ptr=*/shape_ptr,
|
||||||
/*shape_length=*/b_.getInt32(shape_length),
|
/*shape_length=*/b_.getInt32(shape_length),
|
||||||
|
/*num_buffers=*/b_.getInt32(crs->operand_count()),
|
||||||
/*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
|
/*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
|
||||||
/*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
|
/*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -186,6 +186,30 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
|
|||||||
b_->getInt64(offset), name));
|
b_->getInt64(offset), name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Emits IR that packs `arguments` into one array of i8* and returns a pointer
// to that array.
//
// arguments: the values to pack; each is bitcast to i8* and stored at its
//            own index in the array.
// name:      prefix used for the names of the emitted IR values.
// b:         builder used to emit the IR.
//
// Returns a null i8** constant when `arguments` is empty; otherwise the buffer
// obtained from llvm_ir::EmitAllocaAtFunctionEntryWithCount, sized for
// arguments.size() pointers.
llvm::Value* EncodeArrayFunctionArguments(
|
||||||
|
absl::Span<llvm::Value* const> arguments, absl::string_view name,
|
||||||
|
llvm::IRBuilder<>* b) {
|
||||||
|
llvm::Value* arguments_buffer;
|
||||||
|
llvm::Type* int8ptr_ty = b->getInt8PtrTy();
|
||||||
|
if (arguments.empty()) {
|
||||||
|
// Nothing to pack: hand back a null i8** instead of allocating storage.
arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
|
||||||
|
} else {
|
||||||
|
arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
||||||
|
int8ptr_ty, b->getInt32(arguments.size()),
|
||||||
|
absl::StrCat(name, "_parameter_addresses"), b);
|
||||||
|
|
||||||
|
// Store argument i, viewed as i8*, into slot i of the buffer.
for (size_t i = 0; i < arguments.size(); i++) {
|
||||||
|
llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
|
||||||
|
arguments[i], b->getInt8PtrTy(),
|
||||||
|
absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
|
||||||
|
llvm::Value* slot_in_param_addresses =
|
||||||
|
b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
|
||||||
|
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arguments_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
// Emits code to allocate an array of parameter address pointers, and store
|
// Emits code to allocate an array of parameter address pointers, and store
|
||||||
// each address from 'parameter_addresses'.
|
// each address from 'parameter_addresses'.
|
||||||
// Returns an array of compute function call arguments (including parameter
|
// Returns an array of compute function call arguments (including parameter
|
||||||
@ -195,25 +219,8 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
|
|||||||
absl::string_view name, llvm::Value* return_value_buffer,
|
absl::string_view name, llvm::Value* return_value_buffer,
|
||||||
llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
|
llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
|
||||||
llvm::Value* profile_counters_arg) {
|
llvm::Value* profile_counters_arg) {
|
||||||
llvm::Value* parameter_addresses_buffer;
|
llvm::Value* parameter_addresses_buffer =
|
||||||
|
EncodeArrayFunctionArguments(parameter_addresses, name, b);
|
||||||
if (parameter_addresses.empty()) {
|
|
||||||
parameter_addresses_buffer =
|
|
||||||
llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
|
|
||||||
} else {
|
|
||||||
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
|
||||||
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
|
|
||||||
absl::StrCat(name, "_parameter_addresses"), b);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
|
|
||||||
llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
|
|
||||||
parameter_addresses[i], b->getInt8PtrTy(),
|
|
||||||
absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
|
|
||||||
llvm::Value* slot_in_param_addresses =
|
|
||||||
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
|
|
||||||
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto to_int8_ptr = [=](llvm::Value* ptr) {
|
const auto to_int8_ptr = [=](llvm::Value* ptr) {
|
||||||
return b->CreatePointerCast(ptr, b->getInt8PtrTy());
|
return b->CreatePointerCast(ptr, b->getInt8PtrTy());
|
||||||
|
|||||||
@ -114,6 +114,12 @@ class IrFunction {
|
|||||||
llvm::Value* profile_counters_arg_;
|
llvm::Value* profile_counters_arg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns the values in `arguments` encoded as a single buffer of i8*, suitable
|
||||||
|
// for passing to a function call. Each value is bitcast to i8* and stored at
// its index; the result is a null i8** when `arguments` is empty. `name`
// prefixes the names of the emitted IR values.
|
||||||
|
llvm::Value* EncodeArrayFunctionArguments(
|
||||||
|
absl::Span<llvm::Value* const> arguments, absl::string_view name,
|
||||||
|
llvm::IRBuilder<>* b);
|
||||||
|
|
||||||
// Returns an array of compute function call argument ir values.
|
// Returns an array of compute function call argument ir values.
|
||||||
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
|
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
|
||||||
absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
|
absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
|
||||||
|
|||||||
@ -593,5 +593,52 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
|
|||||||
results[3]));
|
results[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exercises a multi-operand (tuple-shaped) all-reduce. Two replicas each run
// an all-reduce over an f32[5] and an f32[7] operand with an `add` reduction,
// both receiving the same input literals, and every replica's result tuple is
// compared element-wise against the expected sums.
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
|
||||||
|
std::string hlo_string = R"(
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
apply_op {
|
||||||
|
x = f32[] parameter(0)
|
||||||
|
y = f32[] parameter(1)
|
||||||
|
ROOT apply_op = f32[] add(x, y)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY test_computation {
|
||||||
|
p0 = f32[5] parameter(0)
|
||||||
|
p1 = f32[7] parameter(1)
|
||||||
|
ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
static constexpr int kNumReplicas = 2;
|
||||||
|
auto config = GetModuleConfigForTest();
|
||||||
|
config.set_replica_count(kNumReplicas);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string, config));
|
||||||
|
|
||||||
|
std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
|
||||||
|
auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
|
||||||
|
std::vector<float> input1_vec = {
|
||||||
|
7., 3., 4., 1., 2., 3., 4.,
|
||||||
|
};
|
||||||
|
auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::vector<Literal> results,
|
||||||
|
ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
|
||||||
|
/*num_replicas=*/kNumReplicas,
|
||||||
|
/*use_threads=*/true));
|
||||||
|
// Both replicas contribute the same inputs, so each expected element is
// twice the corresponding input element.
std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
|
||||||
|
auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
|
||||||
|
std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
|
||||||
|
auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
|
||||||
|
for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
|
||||||
|
auto rs = results[replica_idx].DecomposeTuple();
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
|
||||||
|
ErrorSpec{1e-5, 1e-5}));
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
|
||||||
|
ErrorSpec{1e-5, 1e-5}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user