[XLA:CPU] Support multi-operand all-reduce on CPU
PiperOrigin-RevId: 305169720 Change-Id: I98c64d171fdc02692103260fcf2bc78bb2be7a7b
Parent: ef5bd97017
Commit: 4cd7132e91
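For context, a multi-operand all-reduce applies the reduction computation element-wise to each operand independently and returns a tuple of per-operand results. A minimal HLO sketch of the pattern this change enables on CPU (names here are illustrative; the new AllReduce_TupleAllReduce test added below exercises the same shape of computation):

HloModule tuple_all_reduce_example

add {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT sum = f32[] add(x, y)
}

ENTRY main {
  p0 = f32[5] parameter(0)
  p1 = f32[7] parameter(1)
  ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=add
}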
@@ -263,7 +263,6 @@ class CpuAllReduceRendezvous
  protected:
   xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
       const xla::AllReduceParticipantData& participant) override {
-    TF_RET_CHECK(participant.buffers.size() == 1);
     xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
     bool primary = [&] {
       tensorflow::mutex_lock lock(mu_);
@@ -321,60 +320,87 @@ class CpuAllReduceRendezvous
     for (const auto& p : participants_) {
       CHECK(p.reduction_kind == reduction_kind);
     }
+    int num_participants = participants_.size();
 
-    std::vector<absl::Span<T>> input_buffers;
-    std::vector<absl::Span<T>> output_buffers;
-    input_buffers.reserve(participants_.size());
-    output_buffers.reserve(participants_.size());
+    // participant_idx -> buffer_idx -> buffer.
+    std::vector<std::vector<absl::Span<T>>> input_buffers;
+    std::vector<std::vector<absl::Span<T>>> output_buffers;
+    input_buffers.reserve(num_participants);
+    output_buffers.reserve(num_participants);
+    const xla::AllReduceParticipantData& first_participant =
+        participants_.front();
 
-    for (auto& p : participants_) {
-      CHECK_EQ(p.buffers.size(), 1);
-      CHECK_EQ(p.buffers.front().element_count,
-               participants_.front().buffers.front().element_count);
-      xla::int64 element_count = participant.buffers.front().element_count;
-      input_buffers.emplace_back(
-          static_cast<T*>(p.buffers.front().source_data.opaque()),
-          element_count);
-      output_buffers.emplace_back(
-          static_cast<T*>(p.buffers.front().destination_data.opaque()),
-          element_count);
-    }
-    xla::int64 element_count =
-        participants_.front().buffers.front().element_count;
-
-    auto compute = [reduction_kind](T a, T b) -> T {
-      switch (reduction_kind) {
-        case xla::ReductionKind::SUM:
-          return a + b;
-        case xla::ReductionKind::PRODUCT:
-          return a * b;
-        case xla::ReductionKind::MIN:
-          return std::min(a, b);
-        case xla::ReductionKind::MAX:
-          return std::max(a, b);
-      }
-    };
-
-    for (int idx = 0; idx < element_count; idx++) {
-      T out = [&]() -> T {
-        switch (reduction_kind) {
-          case xla::ReductionKind::SUM:
-            return static_cast<T>(0);
-          case xla::ReductionKind::PRODUCT:
-            return static_cast<T>(1);
-          case xla::ReductionKind::MIN:
-            return std::numeric_limits<T>::max();
-          case xla::ReductionKind::MAX:
-            return std::numeric_limits<T>::min();
-        }
-      }();
-      for (auto& input : input_buffers) {
-        out = compute(out, input[idx]);
-      }
-      for (auto& output : output_buffers) {
-        output[idx] = out;
-      }
-    }
+    int buffers_per_participant = first_participant.buffers.size();
+    for (xla::AllReduceParticipantData& p : participants_) {
+      CHECK_EQ(p.buffers.size(), buffers_per_participant);
+
+      input_buffers.emplace_back();
+      output_buffers.emplace_back();
+      std::vector<absl::Span<T>>& participant_input_buffers =
+          input_buffers.back();
+      std::vector<absl::Span<T>>& participant_output_buffers =
+          output_buffers.back();
+      participant_input_buffers.reserve(p.buffers.size());
+      participant_output_buffers.reserve(p.buffers.size());
+
+      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
+           buffer_idx++) {
+        auto& participant_buffer = p.buffers[buffer_idx];
+        participant_input_buffers.emplace_back(
+            static_cast<T*>(participant_buffer.source_data.opaque()),
+            participant_buffer.element_count);
+        participant_output_buffers.emplace_back(
+            static_cast<T*>(participant_buffer.destination_data.opaque()),
+            participant_buffer.element_count);
+        CHECK_EQ(participant_buffer.element_count,
+                 first_participant.buffers[buffer_idx].element_count);
+      }
+    }
+
+    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
+         buffer_idx++) {
+      int element_count = first_participant.buffers[buffer_idx].element_count;
+      for (int idx = 0; idx < element_count; idx++) {
+        T out = GetInitialValue<T>(reduction_kind);
+        for (int participant_idx = 0; participant_idx < participants_.size();
+             participant_idx++) {
+          out = PerformReductionStep<T>(
+              reduction_kind, out,
+              input_buffers[participant_idx][buffer_idx][idx]);
+        }
+        for (int participant_idx = 0; participant_idx < participants_.size();
+             participant_idx++) {
+          output_buffers[participant_idx][buffer_idx][idx] = out;
+        }
+      }
+    }
+
+  template <typename T>
+  T GetInitialValue(xla::ReductionKind reduction_kind) {
+    switch (reduction_kind) {
+      case xla::ReductionKind::SUM:
+        return static_cast<T>(0);
+      case xla::ReductionKind::PRODUCT:
+        return static_cast<T>(1);
+      case xla::ReductionKind::MIN:
+        return std::numeric_limits<T>::max();
+      case xla::ReductionKind::MAX:
+        return std::numeric_limits<T>::min();
+    }
+  }
+
+  template <typename T>
+  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
+    switch (reduction_kind) {
+      case xla::ReductionKind::SUM:
+        return a + b;
+      case xla::ReductionKind::PRODUCT:
+        return a * b;
+      case xla::ReductionKind::MIN:
+        return std::min(a, b);
+      case xla::ReductionKind::MAX:
+        return std::max(a, b);
+    }
+  }
 };
@@ -392,8 +418,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
     const xla::ExecutableRunOptions* run_options,
     const void* replica_groups_str, xla::int32 replica_groups_str_size,
     xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
-    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
-    void* output_buffer) {
+    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
+    void** input_buffers, void** output_buffers) {
   absl::string_view replica_groups_serialized(
       static_cast<const char*>(replica_groups_str), replica_groups_str_size);
 
@@ -435,21 +461,25 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
 
   xla::Shape shape =
       DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
-  CHECK(xla::LayoutUtil::IsDenseArray(shape))
-      << "All-reduce on CPU is implemented only for dense arrays";
+
+  CHECK((num_buffers > 1 && shape.IsTuple()) ||
+        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));
 
   xla::AllReduceParticipantData participant(rendezvous_key);
   participant.device_ordinal = device_ordinal;
   participant.stream = run_options->stream();
-  xla::AllReduceParticipantData::Buffer buffer;
-  buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
-  buffer.primitive_type = shape.element_type();
-  buffer.source_data =
-      se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  buffer.destination_data =
-      se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  participant.buffers = {buffer};
   participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
+  for (int i = 0; i < num_buffers; i++) {
+    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
+    xla::AllReduceParticipantData::Buffer buffer;
+    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
+    buffer.primitive_type = subshape.element_type();
+    buffer.source_data = se::DeviceMemoryBase(
+        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
+    buffer.destination_data = se::DeviceMemoryBase(
+        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
+    participant.buffers.push_back(buffer);
+  }
 
   auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
     return absl::make_unique<CpuAllReduceRendezvous>(k);
@@ -167,8 +167,8 @@ extern void __xla_cpu_runtime_AllReduce(
     const xla::ExecutableRunOptions* run_options,
     const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
-    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
-    void* output_buffer);
+    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
+    void** input_buffers, void** output_buffers);
 
 // Write the replica ID into the output buffer.
 extern void __xla_cpu_runtime_ReplicaId(
@@ -1458,6 +1458,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
        /*reduction_kind=*/int32_type,
        /*shape_ptr=*/i8_ptr_type,
        /*shape_length=*/int32_type,
+       /*num_buffers=*/int32_type,
        /*input_buffer=*/i8_ptr_type,
        /*output_buffer=*/i8_ptr_type},
       /*isVarArg=*/false);
@@ -1473,19 +1474,41 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
   int32 replica_groups_size = replica_groups.size();
   llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
 
-  Shape shape = crs->operand(0)->shape();
+  bool is_tuple = crs->operand_count() > 1;
+  std::vector<llvm::Value*> input_buffer_ptrs;
+  std::vector<llvm::Value*> output_buffer_ptrs;
+  if (is_tuple) {
+    CHECK(crs->shape().IsTuple());
+
+    for (int64 i = 0; i < crs->operand_count(); i++) {
+      const HloInstruction* op = crs->operand(i);
+      TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
+                          assignment_.GetUniqueSlice(crs, {i}));
+      const Shape& operand_shape = crs->operand(i)->shape();
+      CHECK(operand_shape.IsArray())
+          << "Operands to all-reduce must be arrays: " << crs->ToString();
+      output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
+      input_buffer_ptrs.push_back(GetEmittedValueFor(op));
+    }
+  } else {
+    Shape shape = crs->operand(0)->shape();
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+                        assignment_.GetUniqueSlice(crs->operand(0), {}));
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                        assignment_.GetUniqueSlice(crs, {}));
+    input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
+    output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
+  }
+
+  llvm::Value* input_buffers =
+      EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
+  llvm::Value* output_buffers =
+      EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
 
   int32 shape_length;
-  TF_ASSIGN_OR_RETURN(
-      llvm::Value * shape_ptr,
-      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
-                      assignment_.GetUniqueSlice(crs->operand(0), {}));
-  llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
-                      assignment_.GetUniqueSlice(crs, {}));
-  llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+  TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
+                      llvm_ir::EncodeSelfDescribingShapeConstant(
+                          crs->shape(), &shape_length, &b_));
 
   Call(all_reduce_func,
        {/*run_options=*/GetExecutableRunOptionsArgument(),
@@ -1498,16 +1521,14 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
        b_.getInt64(crs->channel_id().has_value()
                        ? *crs->channel_id()
                        : crs->GetModule()->unique_id()),
 
        /*reduction_kind=*/
        b_.getInt32(
            static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
 
        /*shape_ptr=*/shape_ptr,
        /*shape_length=*/b_.getInt32(shape_length),
 
-       /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
-       /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
+       /*num_buffers=*/b_.getInt32(crs->operand_count()),
+       /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
+       /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
 
   return Status::OK();
 }
@@ -186,6 +186,30 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
                            b_->getInt64(offset), name));
 }
 
+llvm::Value* EncodeArrayFunctionArguments(
+    absl::Span<llvm::Value* const> arguments, absl::string_view name,
+    llvm::IRBuilder<>* b) {
+  llvm::Value* arguments_buffer;
+  llvm::Type* int8ptr_ty = b->getInt8PtrTy();
+  if (arguments.empty()) {
+    arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
+  } else {
+    arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+        int8ptr_ty, b->getInt32(arguments.size()),
+        absl::StrCat(name, "_parameter_addresses"), b);
+
+    for (size_t i = 0; i < arguments.size(); i++) {
+      llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
+          arguments[i], b->getInt8PtrTy(),
+          absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
+      llvm::Value* slot_in_param_addresses =
+          b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
+      b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+    }
+  }
+  return arguments_buffer;
+}
+
 // Emits code to allocate an array of parameter address pointers, and store
 // each address from 'parameter_addresses'.
 // Returns an array of compute function call arguments (including parameter
@@ -195,25 +219,8 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
     absl::string_view name, llvm::Value* return_value_buffer,
     llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
     llvm::Value* profile_counters_arg) {
-  llvm::Value* parameter_addresses_buffer;
-
-  if (parameter_addresses.empty()) {
-    parameter_addresses_buffer =
-        llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
-  } else {
-    parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
-        absl::StrCat(name, "_parameter_addresses"), b);
-
-    for (size_t i = 0; i < parameter_addresses.size(); ++i) {
-      llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
-          parameter_addresses[i], b->getInt8PtrTy(),
-          absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
-      llvm::Value* slot_in_param_addresses =
-          b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
-      b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
-    }
-  }
+  llvm::Value* parameter_addresses_buffer =
+      EncodeArrayFunctionArguments(parameter_addresses, name, b);
 
   const auto to_int8_ptr = [=](llvm::Value* ptr) {
     return b->CreatePointerCast(ptr, b->getInt8PtrTy());
@@ -114,6 +114,12 @@ class IrFunction {
   llvm::Value* profile_counters_arg_;
 };
 
+// Returns arguments in `arguments` encoded as a single buffer, suitable for a
+// function call.
+llvm::Value* EncodeArrayFunctionArguments(
+    absl::Span<llvm::Value* const> arguments, absl::string_view name,
+    llvm::IRBuilder<>* b);
+
 // Returns an array of compute function call argument ir values.
 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
     absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
@@ -593,5 +593,52 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple))
                                      results[3]));
 }
 
+XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
+  std::string hlo_string = R"(
+HloModule test
+
+apply_op {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT apply_op = f32[] add(x, y)
+}
+
+ENTRY test_computation {
+  p0 = f32[5] parameter(0)
+  p1 = f32[7] parameter(1)
+  ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
+}
+)";
+  static constexpr int kNumReplicas = 2;
+  auto config = GetModuleConfigForTest();
+  config.set_replica_count(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+
+  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
+  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
+  std::vector<float> input1_vec = {
+      7., 3., 4., 1., 2., 3., 4.,
+  };
+  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
+                        /*num_replicas=*/kNumReplicas,
+                        /*use_threads=*/true));
+  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
+  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
+  std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
+  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
+  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
+    auto rs = results[replica_idx].DecomposeTuple();
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
+                                             ErrorSpec{1e-5, 1e-5}));
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
+                                             ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
 } // namespace
 } // namespace xla