[XLA:CPU] Support multi-operand all-reduce on CPU

PiperOrigin-RevId: 305169720
Change-Id: I98c64d171fdc02692103260fcf2bc78bb2be7a7b
George Karpenkov 2020-04-06 19:45:54 -07:00 committed by TensorFlower Gardener
parent ef5bd97017
commit 4cd7132e91
6 changed files with 210 additions and 99 deletions


@@ -263,7 +263,6 @@ class CpuAllReduceRendezvous
protected:
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
const xla::AllReduceParticipantData& participant) override {
TF_RET_CHECK(participant.buffers.size() == 1);
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
bool primary = [&] {
tensorflow::mutex_lock lock(mu_);
@@ -321,60 +320,87 @@ class CpuAllReduceRendezvous
for (const auto& p : participants_) {
CHECK(p.reduction_kind == reduction_kind);
}
int num_participants = participants_.size();
std::vector<absl::Span<T>> input_buffers;
std::vector<absl::Span<T>> output_buffers;
input_buffers.reserve(participants_.size());
output_buffers.reserve(participants_.size());
// participant_idx -> buffer_idx -> buffer.
std::vector<std::vector<absl::Span<T>>> input_buffers;
std::vector<std::vector<absl::Span<T>>> output_buffers;
input_buffers.reserve(num_participants);
output_buffers.reserve(num_participants);
const xla::AllReduceParticipantData& first_participant =
participants_.front();
for (auto& p : participants_) {
CHECK_EQ(p.buffers.size(), 1);
CHECK_EQ(p.buffers.front().element_count,
participants_.front().buffers.front().element_count);
xla::int64 element_count = participant.buffers.front().element_count;
input_buffers.emplace_back(
static_cast<T*>(p.buffers.front().source_data.opaque()),
element_count);
output_buffers.emplace_back(
static_cast<T*>(p.buffers.front().destination_data.opaque()),
element_count);
int buffers_per_participant = first_participant.buffers.size();
for (xla::AllReduceParticipantData& p : participants_) {
CHECK_EQ(p.buffers.size(), buffers_per_participant);
input_buffers.emplace_back();
output_buffers.emplace_back();
std::vector<absl::Span<T>>& participant_input_buffers =
input_buffers.back();
std::vector<absl::Span<T>>& participant_output_buffers =
output_buffers.back();
participant_input_buffers.reserve(p.buffers.size());
participant_output_buffers.reserve(p.buffers.size());
for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
buffer_idx++) {
auto& participant_buffer = p.buffers[buffer_idx];
participant_input_buffers.emplace_back(
static_cast<T*>(participant_buffer.source_data.opaque()),
participant_buffer.element_count);
participant_output_buffers.emplace_back(
static_cast<T*>(participant_buffer.destination_data.opaque()),
participant_buffer.element_count);
CHECK_EQ(participant_buffer.element_count,
first_participant.buffers[buffer_idx].element_count);
}
}
xla::int64 element_count =
participants_.front().buffers.front().element_count;
auto compute = [reduction_kind](T a, T b) -> T {
switch (reduction_kind) {
case xla::ReductionKind::SUM:
return a + b;
case xla::ReductionKind::PRODUCT:
return a * b;
case xla::ReductionKind::MIN:
return std::min(a, b);
case xla::ReductionKind::MAX:
return std::max(a, b);
}
};
for (int idx = 0; idx < element_count; idx++) {
  T out = [&]() -> T {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        return std::numeric_limits<T>::min();
    }
  }();
  for (auto& input : input_buffers) {
    out = compute(out, input[idx]);
  }
  for (auto& output : output_buffers) {
    output[idx] = out;
  }
}
for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
     buffer_idx++) {
  int element_count = first_participant.buffers[buffer_idx].element_count;
  for (int idx = 0; idx < element_count; idx++) {
    T out = GetInitialValue<T>(reduction_kind);
    for (int participant_idx = 0; participant_idx < participants_.size();
         participant_idx++) {
      out = PerformReductionStep<T>(
          reduction_kind, out,
          input_buffers[participant_idx][buffer_idx][idx]);
    }
    for (int participant_idx = 0; participant_idx < participants_.size();
         participant_idx++) {
      output_buffers[participant_idx][buffer_idx][idx] = out;
    }
  }
}
}
template <typename T>
T GetInitialValue(xla::ReductionKind reduction_kind) {
switch (reduction_kind) {
case xla::ReductionKind::SUM:
return static_cast<T>(0);
case xla::ReductionKind::PRODUCT:
return static_cast<T>(1);
case xla::ReductionKind::MIN:
return std::numeric_limits<T>::max();
case xla::ReductionKind::MAX:
return std::numeric_limits<T>::min();
}
}
template <typename T>
T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
switch (reduction_kind) {
case xla::ReductionKind::SUM:
return a + b;
case xla::ReductionKind::PRODUCT:
return a * b;
case xla::ReductionKind::MIN:
return std::min(a, b);
case xla::ReductionKind::MAX:
return std::max(a, b);
}
}
};
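For reference, the new loop reduces each buffer independently and element-wise across all participants; below is a minimal standalone sketch of that algorithm, using simplified stand-ins (plain nested vectors and a local ReductionKind enum) rather than the XLA participant and buffer types.

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

enum class ReductionKind { kSum, kProduct, kMin, kMax };

template <typename T>
T InitialValue(ReductionKind kind) {
  switch (kind) {
    case ReductionKind::kSum: return static_cast<T>(0);
    case ReductionKind::kProduct: return static_cast<T>(1);
    case ReductionKind::kMin: return std::numeric_limits<T>::max();
    case ReductionKind::kMax: return std::numeric_limits<T>::lowest();
  }
  return T();
}

template <typename T>
T ReduceStep(ReductionKind kind, T a, T b) {
  switch (kind) {
    case ReductionKind::kSum: return a + b;
    case ReductionKind::kProduct: return a * b;
    case ReductionKind::kMin: return std::min(a, b);
    case ReductionKind::kMax: return std::max(a, b);
  }
  return a;
}

// inputs[participant][buffer] holds that participant's elements for each
// buffer; every participant receives the same fully reduced result per buffer.
template <typename T>
std::vector<std::vector<T>> AllReduce(
    const std::vector<std::vector<std::vector<T>>>& inputs,
    ReductionKind kind) {
  const std::vector<std::vector<T>>& first = inputs.front();
  std::vector<std::vector<T>> result(first.size());
  for (std::size_t buffer_idx = 0; buffer_idx < first.size(); ++buffer_idx) {
    result[buffer_idx].resize(first[buffer_idx].size());
    for (std::size_t idx = 0; idx < first[buffer_idx].size(); ++idx) {
      T out = InitialValue<T>(kind);
      for (const std::vector<std::vector<T>>& participant : inputs) {
        out = ReduceStep<T>(kind, out, participant[buffer_idx][idx]);
      }
      result[buffer_idx][idx] = out;
    }
  }
  return result;
}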
@@ -392,8 +418,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
const xla::ExecutableRunOptions* run_options,
const void* replica_groups_str, xla::int32 replica_groups_str_size,
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
void* output_buffer) {
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
void** input_buffers, void** output_buffers) {
absl::string_view replica_groups_serialized(
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
@@ -435,21 +461,25 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
xla::Shape shape =
DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
CHECK(xla::LayoutUtil::IsDenseArray(shape))
<< "All-reduce on CPU is implemented only for dense arrays";
CHECK((num_buffers > 1 && shape.IsTuple()) ||
(num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));
xla::AllReduceParticipantData participant(rendezvous_key);
participant.device_ordinal = device_ordinal;
participant.stream = run_options->stream();
xla::AllReduceParticipantData::Buffer buffer;
buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
buffer.primitive_type = shape.element_type();
buffer.source_data =
se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
buffer.destination_data =
se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
participant.buffers = {buffer};
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
for (int i = 0; i < num_buffers; i++) {
xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
xla::AllReduceParticipantData::Buffer buffer;
buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
buffer.primitive_type = subshape.element_type();
buffer.source_data = se::DeviceMemoryBase(
input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
buffer.destination_data = se::DeviceMemoryBase(
output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
participant.buffers.push_back(buffer);
}
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
return absl::make_unique<CpuAllReduceRendezvous>(k);


@@ -167,8 +167,8 @@ extern void __xla_cpu_runtime_AllReduce(
const xla::ExecutableRunOptions* run_options,
const void* replica_groups_str, xla::int32 replica_groups_str_size,
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
void* output_buffer);
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
void** input_buffers, void** output_buffers);
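For orientation, here is a sketch of how a caller lays out the two new pointer-array arguments for a two-operand all-reduce. The operand sizes are made up for illustration, and the remaining arguments (run options, replica groups, the self-describing shape, ...) are unchanged from before and elided.

float operand0[5];
float operand1[7];
float result0[5];
float result1[7];

void* input_buffers[] = {operand0, operand1};
void* output_buffers[] = {result0, result1};
// __xla_cpu_runtime_AllReduce(/*...unchanged leading arguments...*/,
//                             /*num_buffers=*/2, input_buffers, output_buffers);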
// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(


@@ -1458,6 +1458,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
/*reduction_kind=*/int32_type,
/*shape_ptr=*/i8_ptr_type,
/*shape_length=*/int32_type,
/*num_buffers=*/int32_type,
/*input_buffer=*/i8_ptr_type,
/*output_buffer=*/i8_ptr_type},
/*isVarArg=*/false);
@@ -1473,19 +1474,41 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
int32 replica_groups_size = replica_groups.size();
llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
Shape shape = crs->operand(0)->shape();
bool is_tuple = crs->operand_count() > 1;
std::vector<llvm::Value*> input_buffer_ptrs;
std::vector<llvm::Value*> output_buffer_ptrs;
if (is_tuple) {
CHECK(crs->shape().IsTuple());
for (int64 i = 0; i < crs->operand_count(); i++) {
const HloInstruction* op = crs->operand(i);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
assignment_.GetUniqueSlice(crs, {i}));
const Shape& operand_shape = crs->operand(i)->shape();
CHECK(operand_shape.IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
input_buffer_ptrs.push_back(GetEmittedValueFor(op));
}
} else {
Shape shape = crs->operand(0)->shape();
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
assignment_.GetUniqueSlice(crs->operand(0), {}));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
assignment_.GetUniqueSlice(crs, {}));
input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
}
llvm::Value* input_buffers =
EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
llvm::Value* output_buffers =
EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
int32 shape_length;
TF_ASSIGN_OR_RETURN(
llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
assignment_.GetUniqueSlice(crs->operand(0), {}));
llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
assignment_.GetUniqueSlice(crs, {}));
llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(
crs->shape(), &shape_length, &b_));
Call(all_reduce_func,
{/*run_options=*/GetExecutableRunOptionsArgument(),
@@ -1498,16 +1521,14 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
b_.getInt64(crs->channel_id().has_value()
? *crs->channel_id()
: crs->GetModule()->unique_id()),
/*reduction_kind=*/
b_.getInt32(
static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
/*shape_ptr=*/shape_ptr,
/*shape_length=*/b_.getInt32(shape_length),
/*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
/*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
/*num_buffers=*/b_.getInt32(crs->operand_count()),
/*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
/*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
return Status::OK();
}


@@ -186,6 +186,30 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
b_->getInt64(offset), name));
}
llvm::Value* EncodeArrayFunctionArguments(
absl::Span<llvm::Value* const> arguments, absl::string_view name,
llvm::IRBuilder<>* b) {
llvm::Value* arguments_buffer;
llvm::Type* int8ptr_ty = b->getInt8PtrTy();
if (arguments.empty()) {
arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
} else {
arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
int8ptr_ty, b->getInt32(arguments.size()),
absl::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < arguments.size(); i++) {
llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
arguments[i], b->getInt8PtrTy(),
absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}
}
return arguments_buffer;
}
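Viewed from the runtime's side, the IR this helper emits just fills an array of i8* slots with the bitcast argument pointers, or yields a null pointer when there are none. Roughly the following plain C++, as an analogy only, not emitted code:

#include <cstddef>

void** EncodeArguments(void** storage, void* const* args, std::size_t count) {
  if (count == 0) return nullptr;  // mirrors the null-buffer case above
  for (std::size_t i = 0; i < count; ++i) {
    storage[i] = args[i];          // each argument pointer stored as an i8*
  }
  return storage;                  // plays the role of the function-entry alloca
}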
// Emits code to allocate an array of parameter address pointers, and store
// each address from 'parameter_addresses'.
// Returns an array of compute function call arguments (including parameter
@@ -195,25 +219,8 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
absl::string_view name, llvm::Value* return_value_buffer,
llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
if (parameter_addresses.empty()) {
parameter_addresses_buffer =
llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
} else {
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
absl::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
parameter_addresses[i], b->getInt8PtrTy(),
absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}
}
llvm::Value* parameter_addresses_buffer =
EncodeArrayFunctionArguments(parameter_addresses, name, b);
const auto to_int8_ptr = [=](llvm::Value* ptr) {
return b->CreatePointerCast(ptr, b->getInt8PtrTy());


@@ -114,6 +114,12 @@ class IrFunction {
llvm::Value* profile_counters_arg_;
};
// Returns arguments in `arguments` encoded as a single buffer, suitable for a
// function call.
llvm::Value* EncodeArrayFunctionArguments(
absl::Span<llvm::Value* const> arguments, absl::string_view name,
llvm::IRBuilder<>* b);
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,


@@ -593,5 +593,52 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
results[3]));
}
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
std::string hlo_string = R"(
HloModule test
apply_op {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT apply_op = f32[] add(x, y)
}
ENTRY test_computation {
p0 = f32[5] parameter(0)
p1 = f32[7] parameter(1)
ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
}
)";
static constexpr int kNumReplicas = 2;
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
std::vector<float> input1_vec = {
7., 3., 4., 1., 2., 3., 4.,
};
auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
/*num_replicas=*/kNumReplicas,
/*use_threads=*/true));
std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
auto rs = results[replica_idx].DecomposeTuple();
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
ErrorSpec{1e-5, 1e-5}));
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
ErrorSpec{1e-5, 1e-5}));
}
}
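As a cross-check of the expected literals above: both replicas feed identical inputs and the reduction is an add, so every expected element is simply the input element times kNumReplicas. A standalone sketch of that arithmetic (not part of the test):

#include <cstddef>
#include <vector>

std::vector<float> ExpectedSum(const std::vector<float>& input, int num_replicas) {
  std::vector<float> expected(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    expected[i] = input[i] * num_replicas;  // e.g. {1, 2, 3, 4, 5} -> {2, 4, 6, 8, 10}
  }
  return expected;
}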
} // namespace
} // namespace xla