From 4cd7132e91952843b4c817cbfcafcbfbf4a10325 Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Mon, 6 Apr 2020 19:45:54 -0700
Subject: [PATCH] [XLA:CPU] Support multi-operand all-reduce on CPU

PiperOrigin-RevId: 305169720
Change-Id: I98c64d171fdc02692103260fcf2bc78bb2be7a7b
---
 .../compiler/xla/service/cpu/cpu_runtime.cc   | 152 +++++++++++-------
 .../compiler/xla/service/cpu/cpu_runtime.h    |   4 +-
 .../compiler/xla/service/cpu/ir_emitter.cc    |  55 +++++--
 .../compiler/xla/service/cpu/ir_function.cc   |  45 +++---
 .../compiler/xla/service/cpu/ir_function.h    |   6 +
 .../compiler/xla/tests/collective_ops_test.cc |  47 ++++++
 6 files changed, 210 insertions(+), 99 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 201eb936a1b..d6f64828c32 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -263,7 +263,6 @@ class CpuAllReduceRendezvous
  protected:
  xla::StatusOr SubmitParticipantImpl(
      const xla::AllReduceParticipantData& participant) override {
-    TF_RET_CHECK(participant.buffers.size() == 1);
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = [&] {
      tensorflow::mutex_lock lock(mu_);
@@ -321,60 +320,87 @@ class CpuAllReduceRendezvous
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
+    int num_participants = participants_.size();

-    std::vector<absl::Span<T>> input_buffers;
-    std::vector<absl::Span<T>> output_buffers;
-    input_buffers.reserve(participants_.size());
-    output_buffers.reserve(participants_.size());
+    // participant_idx -> buffer_idx -> buffer.
+    std::vector<std::vector<absl::Span<T>>> input_buffers;
+    std::vector<std::vector<absl::Span<T>>> output_buffers;
+    input_buffers.reserve(num_participants);
+    output_buffers.reserve(num_participants);
+    const xla::AllReduceParticipantData& first_participant =
+        participants_.front();

-    for (auto& p : participants_) {
-      CHECK_EQ(p.buffers.size(), 1);
-      CHECK_EQ(p.buffers.front().element_count,
-               participants_.front().buffers.front().element_count);
-      xla::int64 element_count = participant.buffers.front().element_count;
-      input_buffers.emplace_back(
-          static_cast<T*>(p.buffers.front().source_data.opaque()),
-          element_count);
-      output_buffers.emplace_back(
-          static_cast<T*>(p.buffers.front().destination_data.opaque()),
-          element_count);
+    int buffers_per_participant = first_participant.buffers.size();
+    for (xla::AllReduceParticipantData& p : participants_) {
+      CHECK_EQ(p.buffers.size(), buffers_per_participant);
+
+      input_buffers.emplace_back();
+      output_buffers.emplace_back();
+      std::vector<absl::Span<T>>& participant_input_buffers =
+          input_buffers.back();
+      std::vector<absl::Span<T>>& participant_output_buffers =
+          output_buffers.back();
+      participant_input_buffers.reserve(p.buffers.size());
+      participant_output_buffers.reserve(p.buffers.size());
+
+      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
+           buffer_idx++) {
+        auto& participant_buffer = p.buffers[buffer_idx];
+        participant_input_buffers.emplace_back(
+            static_cast<T*>(participant_buffer.source_data.opaque()),
+            participant_buffer.element_count);
+        participant_output_buffers.emplace_back(
+            static_cast<T*>(participant_buffer.destination_data.opaque()),
+            participant_buffer.element_count);
+        CHECK_EQ(participant_buffer.element_count,
+                 first_participant.buffers[buffer_idx].element_count);
+      }
    }
-    xla::int64 element_count =
-        participants_.front().buffers.front().element_count;
-    auto compute = [reduction_kind](T a, T b) -> T {
-      switch (reduction_kind) {
-        case xla::ReductionKind::SUM:
-          return a + b;
-        case xla::ReductionKind::PRODUCT:
-          return a * b;
-        case xla::ReductionKind::MIN:
-          return std::min(a, b);
-        case xla::ReductionKind::MAX:
-          return std::max(a, b);
-      }
-    };
-
-    for (int idx = 0; idx < element_count; idx++) {
-      T out = [&]() -> T {
-        switch (reduction_kind) {
-          case xla::ReductionKind::SUM:
-            return static_cast<T>(0);
-          case xla::ReductionKind::PRODUCT:
-            return static_cast<T>(1);
-          case xla::ReductionKind::MIN:
-            return std::numeric_limits<T>::max();
-          case xla::ReductionKind::MAX:
-            return std::numeric_limits<T>::min();
+    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
+         buffer_idx++) {
+      int element_count = first_participant.buffers[buffer_idx].element_count;
+      for (int idx = 0; idx < element_count; idx++) {
+        T out = GetInitialValue<T>(reduction_kind);
+        for (int participant_idx = 0; participant_idx < participants_.size();
+             participant_idx++) {
+          out = PerformReductionStep(
+              reduction_kind, out,
+              input_buffers[participant_idx][buffer_idx][idx]);
        }
-      }();
+        for (int participant_idx = 0; participant_idx < participants_.size();
+             participant_idx++) {
+          output_buffers[participant_idx][buffer_idx][idx] = out;
+        }
+      }
+    }
+  }
-      for (auto& input : input_buffers) {
-        out = compute(out, input[idx]);
-      }
-      for (auto& output : output_buffers) {
-        output[idx] = out;
-      }
+  template <typename T>
+  T GetInitialValue(xla::ReductionKind reduction_kind) {
+    switch (reduction_kind) {
+      case xla::ReductionKind::SUM:
+        return static_cast<T>(0);
+      case xla::ReductionKind::PRODUCT:
+        return static_cast<T>(1);
+      case xla::ReductionKind::MIN:
+        return std::numeric_limits<T>::max();
+      case xla::ReductionKind::MAX:
+        return std::numeric_limits<T>::min();
+    }
+  }
+
+  template <typename T>
+  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
+    switch (reduction_kind) {
+      case xla::ReductionKind::SUM:
+        return a + b;
+      case xla::ReductionKind::PRODUCT:
+        return a * b;
+      case xla::ReductionKind::MIN:
+        return std::min(a, b);
+      case xla::ReductionKind::MAX:
+        return std::max(a, b);
    }
  }
};
@@ -392,8 +418,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
-    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
-    void* output_buffer) {
+    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
+    void** input_buffers, void** output_buffers) {
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
@@ -435,21 +461,25 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();
-  CHECK(xla::LayoutUtil::IsDenseArray(shape))
-      << "All-reduce on CPU is implemented only for dense arrays";
+
+  CHECK((num_buffers > 1 && shape.IsTuple()) ||
+        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key);
  participant.device_ordinal = device_ordinal;
  participant.stream = run_options->stream();
-  xla::AllReduceParticipantData::Buffer buffer;
-  buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
-  buffer.primitive_type = shape.element_type();
-  buffer.source_data =
-      se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  buffer.destination_data =
-      se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  participant.buffers = {buffer};
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
+  for (int i = 0; i < num_buffers; i++) {
+    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
+    xla::AllReduceParticipantData::Buffer buffer;
+    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
+    buffer.primitive_type = subshape.element_type();
+    buffer.source_data = se::DeviceMemoryBase(
+        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
+    buffer.destination_data = se::DeviceMemoryBase(
+        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
+    participant.buffers.push_back(buffer);
+  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllReduceRendezvous>(k);
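The rendezvous above now reduces several buffers per participant: for every buffer index and element index, the output is the reduction of the matching input elements across all participants, and every participant receives the same result. A minimal host-side sketch of that semantics, in plain C++ and independent of the XLA types in the diff (ReductionKind, AllReduceOnHost, and the nested-vector layout are illustrative names, not part of the patch):

#include <algorithm>
#include <cassert>
#include <vector>

// Illustrative stand-in for xla::ReductionKind.
enum class ReductionKind { kSum, kProduct, kMin, kMax };

// participant -> buffer -> elements, mirroring the
// "participant_idx -> buffer_idx -> buffer" nesting used by the rendezvous.
template <typename T>
std::vector<std::vector<T>> AllReduceOnHost(
    const std::vector<std::vector<std::vector<T>>>& inputs,
    ReductionKind kind) {
  assert(!inputs.empty());
  const size_t num_buffers = inputs.front().size();
  std::vector<std::vector<T>> outputs(num_buffers);
  for (size_t b = 0; b < num_buffers; ++b) {
    const size_t element_count = inputs.front()[b].size();
    outputs[b].resize(element_count);
    for (size_t i = 0; i < element_count; ++i) {
      // Seed with participant 0's element, then fold in the remaining
      // participants one reduction step at a time.
      T out = inputs[0][b][i];
      for (size_t p = 1; p < inputs.size(); ++p) {
        const T v = inputs[p][b][i];
        switch (kind) {
          case ReductionKind::kSum: out = out + v; break;
          case ReductionKind::kProduct: out = out * v; break;
          case ReductionKind::kMin: out = std::min(out, v); break;
          case ReductionKind::kMax: out = std::max(out, v); break;
        }
      }
      outputs[b][i] = out;
    }
  }
  return outputs;
}

The sketch seeds each element with participant 0's value rather than a reduction identity; the patch instead seeds with GetInitialValue<T>() and folds in every participant, which produces the same result. The write-back of the shared value corresponds to the assignment output_buffers[participant_idx][buffer_idx][idx] = out above.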
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 598ab353e80..6af41dea484 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -167,8 +167,8 @@ extern void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
-    const void* shape_ptr, xla::int32 shape_length, void* input_buffer,
-    void* output_buffer);
+    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
+    void** input_buffers, void** output_buffers);

// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(
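The declaration change above defines the new runtime ABI: instead of one input and one output pointer, the generated code passes num_buffers plus two pointer arrays with one entry per all-reduce operand, while shape_ptr encodes either a single dense array shape or a tuple whose i-th element describes buffer i. A hedged, hand-written illustration of that packing convention for an (f32[5], f32[7]) all-reduce; in practice these arrays are built on the stack by the LLVM IR emitted in ir_emitter.cc below, not by C++ code like this:

int main() {
  // Operands of a tuple-shaped all-reduce over (f32[5], f32[7]).
  float in0[5] = {1, 2, 3, 4, 5};
  float in1[7] = {7, 3, 4, 1, 2, 3, 4};
  float out0[5] = {};
  float out1[7] = {};

  // num_buffers == 2; entry i corresponds to shape.tuple_shapes(i).
  void* input_buffers[] = {in0, in1};
  void* output_buffers[] = {out0, out1};

  // Arrays like these, bitcast to i8**, are what the new num_buffers /
  // input_buffers / output_buffers arguments of __xla_cpu_runtime_AllReduce
  // receive; the runtime then builds one AllReduceParticipantData::Buffer per
  // entry, sized by ShapeUtil::ByteSizeOf of the matching subshape.
  (void)input_buffers;
  (void)output_buffers;
  return 0;
}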
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 81bee71360d..cef45128ea0 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1458,6 +1458,7 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
          /*reduction_kind=*/int32_type,
          /*shape_ptr=*/i8_ptr_type,
          /*shape_length=*/int32_type,
+          /*num_buffers=*/int32_type,
          /*input_buffer=*/i8_ptr_type,
          /*output_buffer=*/i8_ptr_type},
         /*isVarArg=*/false);
@@ -1473,19 +1474,41 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
  int32 replica_groups_size = replica_groups.size();
  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);

-  Shape shape = crs->operand(0)->shape();
+  bool is_tuple = crs->operand_count() > 1;
+  std::vector<llvm::Value*> input_buffer_ptrs;
+  std::vector<llvm::Value*> output_buffer_ptrs;
+  if (is_tuple) {
+    CHECK(crs->shape().IsTuple());
+
+    for (int64 i = 0; i < crs->operand_count(); i++) {
+      const HloInstruction* op = crs->operand(i);
+      TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
+                          assignment_.GetUniqueSlice(crs, {i}));
+      const Shape& operand_shape = crs->operand(i)->shape();
+      CHECK(operand_shape.IsArray())
+          << "Operands to all-reduce must be arrays: " << crs->ToString();
+      output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
+      input_buffer_ptrs.push_back(GetEmittedValueFor(op));
+    }
+  } else {
+    Shape shape = crs->operand(0)->shape();
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+                        assignment_.GetUniqueSlice(crs->operand(0), {}));
+    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                        assignment_.GetUniqueSlice(crs, {}));
+    input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
+    output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
+  }
+
+  llvm::Value* input_buffers =
+      EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
+  llvm::Value* output_buffers =
+      EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
+
  int32 shape_length;
-  TF_ASSIGN_OR_RETURN(
-      llvm::Value * shape_ptr,
-      llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
-                      assignment_.GetUniqueSlice(crs->operand(0), {}));
-  llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
-
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
-                      assignment_.GetUniqueSlice(crs, {}));
-  llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+  TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
+                      llvm_ir::EncodeSelfDescribingShapeConstant(
+                          crs->shape(), &shape_length, &b_));

  Call(all_reduce_func,
       {/*run_options=*/GetExecutableRunOptionsArgument(),
@@ -1498,16 +1521,14 @@ Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
           b_.getInt64(crs->channel_id().has_value()
                           ? *crs->channel_id()
                           : crs->GetModule()->unique_id()),
-       /*reduction_kind=*/
       b_.getInt32(
           static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
-
       /*shape_ptr=*/shape_ptr,
       /*shape_length=*/b_.getInt32(shape_length),
-
-       /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
-       /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)});
+       /*num_buffers=*/b_.getInt32(crs->operand_count()),
+       /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
+       /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});

  return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 42acd72f966..6553dc16748 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -186,6 +186,30 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
      b_->getInt64(offset), name));
}

+llvm::Value* EncodeArrayFunctionArguments(
+    absl::Span<llvm::Value* const> arguments, absl::string_view name,
+    llvm::IRBuilder<>* b) {
+  llvm::Value* arguments_buffer;
+  llvm::Type* int8ptr_ty = b->getInt8PtrTy();
+  if (arguments.empty()) {
+    arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
+  } else {
+    arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+        int8ptr_ty, b->getInt32(arguments.size()),
+        absl::StrCat(name, "_parameter_addresses"), b);
+
+    for (size_t i = 0; i < arguments.size(); i++) {
+      llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
+          arguments[i], b->getInt8PtrTy(),
+          absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
+      llvm::Value* slot_in_param_addresses =
+          b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
+      b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+    }
+  }
+  return arguments_buffer;
+}
+
// Emits code to allocate an array of parameter address pointers, and store
// each address from 'parameter_addresses'.
// Returns an array of compute function call arguments (including parameter
@@ -195,25 +219,8 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
    absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
    absl::string_view name, llvm::Value* return_value_buffer,
    llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
    llvm::Value* profile_counters_arg) {
-  llvm::Value* parameter_addresses_buffer;
-
-  if (parameter_addresses.empty()) {
-    parameter_addresses_buffer =
-        llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
-  } else {
-    parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
-        b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
-        absl::StrCat(name, "_parameter_addresses"), b);
-
-    for (size_t i = 0; i < parameter_addresses.size(); ++i) {
-      llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
-          parameter_addresses[i], b->getInt8PtrTy(),
-          absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
-      llvm::Value* slot_in_param_addresses =
-          b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
-      b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
-    }
-  }
+  llvm::Value* parameter_addresses_buffer =
+      EncodeArrayFunctionArguments(parameter_addresses, name, b);

  const auto to_int8_ptr = [=](llvm::Value* ptr) {
    return b->CreatePointerCast(ptr, b->getInt8PtrTy());
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index 02bcec9dfc7..cc0c1d30e14 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -114,6 +114,12 @@ class IrFunction {
  llvm::Value* profile_counters_arg_;
};

+// Returns arguments in `arguments` encoded as a single buffer, suitable for a
+// function call.
+llvm::Value* EncodeArrayFunctionArguments(
+    absl::Span<llvm::Value* const> arguments, absl::string_view name,
+    llvm::IRBuilder<>* b);
+
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
    absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 3aacf065156..380486357f7 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -593,5 +593,52 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
                                     results[3]));
}

+XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
+  std::string hlo_string = R"(
+    HloModule test
+
+    apply_op {
+      x = f32[] parameter(0)
+      y = f32[] parameter(1)
+      ROOT apply_op = f32[] add(x, y)
+    }
+
+    ENTRY test_computation {
+      p0 = f32[5] parameter(0)
+      p1 = f32[7] parameter(1)
+      ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
+    }
+  )";
+  static constexpr int kNumReplicas = 2;
+  auto config = GetModuleConfigForTest();
+  config.set_replica_count(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+
+  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
+  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
+  std::vector<float> input1_vec = {
+      7., 3., 4., 1., 2., 3., 4.,
+  };
+  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
+                        /*num_replicas=*/kNumReplicas,
+                        /*use_threads=*/true));
+  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
+  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
+  std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
+  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
+  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
+    auto rs = results[replica_idx].DecomposeTuple();
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
+                                             ErrorSpec{1e-5, 1e-5}));
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
+                                             ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
}  // namespace
}  // namespace xla
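As a sanity check on the expected literals in the test above: both replicas feed identical inputs, so the element-wise sum across kNumReplicas = 2 replicas is simply twice each input. A small standalone check of that arithmetic (plain C++, outside the test harness):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int kNumReplicas = 2;
  std::vector<float> input0 = {1, 2, 3, 4, 5};
  std::vector<float> input1 = {7, 3, 4, 1, 2, 3, 4};
  std::vector<float> expected0 = {2, 4, 6, 8, 10};
  std::vector<float> expected1 = {14, 6, 8, 2, 4, 6, 8};

  // Every replica contributes the same literal, so the all-reduce sum of
  // tuple element i is kNumReplicas * input_i, element-wise.
  for (size_t i = 0; i < input0.size(); ++i) {
    assert(kNumReplicas * input0[i] == expected0[i]);
  }
  for (size_t i = 0; i < input1.size(); ++i) {
    assert(kNumReplicas * input1[i] == expected1[i]);
  }
  return 0;
}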