[XLA:CPU] AllToAll support for XLA:CPU
A single master thread performs all the work.

PiperOrigin-RevId: 320074537
Change-Id: Iaa4e4a78b0f058ffdb11334a12e8b78126399e89
parent a1b927ce1d
commit 052263c130
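The "single master thread" note above refers to the rendezvous scheme visible in CpuAllToAllRendezvous in the diff below: every replica's calling thread submits its participant data, exactly one participant becomes the primary and performs all of the buffer copies under a lock, and the remaining threads simply block until the primary finishes. The following is a minimal, self-contained sketch of that pattern using only the C++ standard library; SimpleRendezvous, Participant, and the trivial copy step are invented for illustration and are not the actual xla::Rendezvous API.

// Simplified sketch (hypothetical names): last thread to arrive acts as the
// primary and does all the work; everyone else waits for it to finish.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

struct Participant {
  std::vector<char>* source;
  std::vector<char>* destination;
};

class SimpleRendezvous {
 public:
  explicit SimpleRendezvous(int num_participants)
      : num_participants_(num_participants) {}

  // Called by every participating thread. Returns once the collective is done.
  void Submit(Participant participant) {
    std::unique_lock<std::mutex> lock(mu_);
    participants_.push_back(participant);
    bool is_primary = (participants_.size() == num_participants_);
    if (is_primary) {
      // The primary performs every copy; stand-in for the real routing logic.
      for (Participant& p : participants_) {
        *p.destination = *p.source;
      }
      done_ = true;
      cv_.notify_all();
    } else {
      // Non-primary participants just wait for the primary to finish.
      cv_.wait(lock, [this] { return done_; });
    }
  }

 private:
  const size_t num_participants_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::vector<Participant> participants_;
  bool done_ = false;
};

int main() {
  constexpr int kNumReplicas = 4;
  SimpleRendezvous rendezvous(kNumReplicas);
  std::vector<std::vector<char>> sources(kNumReplicas, std::vector<char>{1, 2});
  std::vector<std::vector<char>> destinations(kNumReplicas);

  std::vector<std::thread> threads;
  for (int i = 0; i < kNumReplicas; ++i) {
    threads.emplace_back(
        [&, i] { rendezvous.Submit({&sources[i], &destinations[i]}); });
  }
  for (std::thread& t : threads) t.join();
  std::cout << "all replicas done\n";
}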
@@ -122,6 +122,7 @@ extern const char* const kTracingStartSymbolName =
 extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
+extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
 extern const char* const kCollectivePermuteSymbolName =
     "__xla_cpu_runtime_CollectivePermute";
 extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
@@ -154,6 +155,34 @@ struct CollectivePermuteParticipantData : xla::ParticipantData {
   }
 };
 
+struct AllToAllParticipantData : xla::ParticipantData {
+  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
+                          xla::int64 device_ordinal_p, se::Stream* stream_p)
+      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}
+
+  std::vector<se::DeviceMemoryBase> source_buffers;
+  std::vector<se::DeviceMemoryBase> destination_buffers;
+  int replica_id;
+
+  // Replica ids participating in AllToAll, concatenation happens in the order
+  // of appearance.
+  std::vector<xla::int64> replica_ids_to_copy_to;
+
+  std::string ToString() const override {
+    auto addr_formatter = [](std::string* out,
+                             const se::DeviceMemoryBase& mem) {
+      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
+    };
+    return absl::StrFormat(
+        "AllToAllParticipantData{replica_id=%d, "
+        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
+        "destination_buffers=[%s]}",
+        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
+        absl::StrJoin(source_buffers, ", ", addr_formatter),
+        absl::StrJoin(destination_buffers, ", ", addr_formatter));
+  }
+};
+
 // Inverses the encoding of a Shape protobuf into an LLVM global variable.
 xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
     const void* shape_ptr, xla::int32 size_bytes) {
@@ -286,6 +315,70 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
 
 namespace {
 
+class CpuAllToAllRendezvous
+    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
+      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}
+
+ protected:
+  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+      const AllToAllParticipantData& /*participant*/) override {
+    bool is_primary = InitializationBarrier();
+
+    if (is_primary) {
+      tensorflow::mutex_lock lock(mu_);
+
+      CHECK(!participants_.empty());
+      CHECK(!participants_[0].source_buffers.empty());
+      int expected_buffer_size = participants_[0].source_buffers[0].size();
+
+      // Replica id -> position in participants_.
+      absl::flat_hash_map<int, int> replica_id_map;
+
+      for (int pos = 0; pos < participants_.size(); pos++) {
+        const AllToAllParticipantData& p = participants_[pos];
+        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
+        CHECK_EQ(p.source_buffers.size(), participants_.size());
+        for (int i = 0; i < p.source_buffers.size(); i++) {
+          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
+          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
+        }
+        replica_id_map[p.replica_id] = pos;
+      }
+
+      for (AllToAllParticipantData& p : participants_) {
+        VLOG(3) << "Processing AllToAll participant data: " << p.ToString();
+        for (int j = 0; j < p.source_buffers.size(); j++) {
+          for (int i = 0; i < p.replica_ids_to_copy_to.size(); i++) {
+            int replica_id = p.replica_ids_to_copy_to[i];
+            int participant_num = xla::FindOrDie(replica_id_map, replica_id);
+            AllToAllParticipantData& other = participants_[participant_num];
+
+            // Sort by replica ordering.
+            std::vector<se::DeviceMemoryBase> destination_buffers =
+                other.destination_buffers;
+            absl::flat_hash_map<const void*, int> buffers_index;
+            for (int idx = 0; idx < destination_buffers.size(); idx++) {
+              buffers_index[destination_buffers[idx].opaque()] = idx;
+            }
+            absl::c_sort(
+                destination_buffers, [&](const se::DeviceMemoryBase& a,
+                                         const se::DeviceMemoryBase& b) {
+                  return p.replica_ids_to_copy_to[buffers_index[a.opaque()]] <
+                         p.replica_ids_to_copy_to[buffers_index[b.opaque()]];
+                });
+
+            std::memcpy(destination_buffers[j].opaque(),
+                        p.source_buffers[j].opaque(), expected_buffer_size);
+          }
+        }
+      }
+    }
+    return ParticipantImplOutput{is_primary, nullptr};
+  }
+};
+
 class CpuCollectivePermuteRendezvous
     : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
  public:
@@ -486,6 +579,13 @@ GlobalCollectivePermuteRendezvousMap() {
   return m;
 }
 
+xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
+GlobalAllToAllRendezvousMap() {
+  static auto& m =
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
+  return m;
+}
+
 int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
   if (run_options->stream()) {
     return run_options->stream()->parent()->device_ordinal();
@@ -524,6 +624,48 @@ xla::RendezvousKey GetRendezvousKey(
 
 }  // namespace
 
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, const void* replica_groups_str,
+    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
+    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
+  int device_ordinal = GetDeviceOrdinal(run_options);
+  xla::int32 replica_id = run_options->device_assignment()
+                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
+                              .ValueOrDie();
+  absl::string_view replica_groups_serialized(
+      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+  std::vector<xla::ReplicaGroup> group =
+      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
+  xla::RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, group, channel_id_present, op_id);
+
+  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
+                                      run_options->stream());
+  participant.replica_id = replica_id;
+  participant.replica_ids_to_copy_to =
+      xla::GetParticipatingReplicas(
+          xla::GlobalDeviceId(device_ordinal), group,
+          run_options->device_assignment()->replica_count(),
+          *run_options->device_assignment())
+          .ValueOrDie();
+  for (int i = 0; i < num_buffers; i++) {
+    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
+    participant.destination_buffers.emplace_back(destination_buffers[i],
+                                                 buffer_size);
+  }
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuAllToAllRendezvous>(k);
+  };
+  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
+                  [&] {
+                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
+                        rendezvous_key, make_cpu_rendezvous);
+                  },
+                  participant)
+                  .status());
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
     const xla::ExecutableRunOptions* run_options,
     const void* replica_groups_str, xla::int32 replica_groups_str_size,
@@ -77,6 +77,7 @@ extern const char* const kCollectivePermuteSymbolName;
 extern const char* const kReplicaIdSymbolName;
 extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
+extern const char* const kAllToAllSymbolName;
 
 // All symbol names for XLA CPU runtime functions need to start with this
 // prefix.
@@ -181,6 +182,12 @@ extern void __xla_cpu_runtime_CollectivePermute(
     void* output_buffer, const void* source_target_pairs,
     xla::int32 source_target_pairs_size);
 
+extern void __xla_cpu_runtime_AllToAll(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, const void* replica_groups_str,
+    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
+    xla::int64 buffer_size, void** source_buffers, void** destination_buffers);
+
 // Write the replica ID into the output buffer.
 extern void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);
@@ -359,7 +359,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   // to the output buffer of its corresponding operand. A GetTupleElement
   // instruction forwards a pointer to the tuple element buffer at the given
   // index.
-  auto operand = get_tuple_element->operand(0);
+  const HloInstruction* operand = get_tuple_element->operand(0);
   const Shape& shape = get_tuple_element->shape();
   emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
       shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
|
|||||||
return HandleAllReduceMultipleReplica(crs);
|
return HandleAllReduceMultipleReplica(crs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
|
||||||
|
auto* instr = Cast<HloAllToAllInstruction>(instruction);
|
||||||
|
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
|
||||||
|
CHECK(!instr->split_dimension() && instr->shape().IsTuple())
|
||||||
|
<< "Only tuple AllToAll is supported";
|
||||||
|
|
||||||
|
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
|
||||||
|
llvm::Type* int32_type = b_.getInt32Ty();
|
||||||
|
llvm::Type* int64_type = b_.getInt64Ty();
|
||||||
|
|
||||||
|
// TODO(cheshire): 3 statements below should be a single line.
|
||||||
|
llvm::FunctionType* all_to_all_func_ty =
|
||||||
|
llvm::FunctionType::get(b_.getVoidTy(),
|
||||||
|
{/*run_options=*/i8_ptr_type,
|
||||||
|
/*channel_id_present=*/int32_type,
|
||||||
|
/*op_id=*/int64_type,
|
||||||
|
/*replica_groups=*/i8_ptr_type,
|
||||||
|
/*replica_groups_size=*/int32_type,
|
||||||
|
/*num_buffers=*/int32_type,
|
||||||
|
/*buffer_size=*/int64_type,
|
||||||
|
/*input_buffer=*/i8_ptr_type,
|
||||||
|
/*output_buffer=*/i8_ptr_type},
|
||||||
|
/*isVarArg=*/false);
|
||||||
|
auto all_to_all_func = llvm::dyn_cast<llvm::Function>(
|
||||||
|
module_
|
||||||
|
->getOrInsertFunction(runtime::kAllToAllSymbolName,
|
||||||
|
all_to_all_func_ty)
|
||||||
|
.getCallee());
|
||||||
|
all_to_all_func->setCallingConv(llvm::CallingConv::C);
|
||||||
|
|
||||||
|
std::string replica_groups =
|
||||||
|
ReplicaGroupsToString(instruction->replica_groups());
|
||||||
|
int32 replica_groups_size = replica_groups.size();
|
||||||
|
llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
|
||||||
|
|
||||||
|
int64 buffer_size = -1;
|
||||||
|
std::vector<llvm::Value*> input_buffer_ptrs;
|
||||||
|
std::vector<llvm::Value*> output_buffer_ptrs;
|
||||||
|
|
||||||
|
for (int64 i = 0; i < instruction->operand_count(); i++) {
|
||||||
|
const HloInstruction* op = instruction->operand(i);
|
||||||
|
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
|
||||||
|
assignment_.GetUniqueSlice(instruction, {i}));
|
||||||
|
const Shape& operand_shape = instruction->operand(i)->shape();
|
||||||
|
CHECK(operand_shape.IsArray())
|
||||||
|
<< "Operands to all-to-all must be arrays: " << instruction->ToString();
|
||||||
|
output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
|
||||||
|
input_buffer_ptrs.push_back(GetEmittedValueFor(op));
|
||||||
|
CHECK(buffer_size == -1 || buffer_size == out_slice.size());
|
||||||
|
buffer_size = out_slice.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Value* input_buffers =
|
||||||
|
EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
|
||||||
|
llvm::Value* output_buffers =
|
||||||
|
EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
|
||||||
|
|
||||||
|
b_.CreateCall(
|
||||||
|
all_to_all_func,
|
||||||
|
{/*run_options=*/GetExecutableRunOptionsArgument(),
|
||||||
|
/*channel_id_present=*/
|
||||||
|
b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
|
||||||
|
/*op_id=*/
|
||||||
|
b_.getInt64(instruction->channel_id().has_value()
|
||||||
|
? *instruction->channel_id()
|
||||||
|
: instruction->GetModule()->unique_id()),
|
||||||
|
/*replica_groups=*/replica_groups_v,
|
||||||
|
/*replica_groups_size=*/b_.getInt32(replica_groups_size),
|
||||||
|
/*num_buffers=*/b_.getInt32(instruction->operand_count()),
|
||||||
|
/*buffer_size=*/b_.getInt64(buffer_size),
|
||||||
|
/*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
|
||||||
|
/*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
|
||||||
|
|
||||||
|
llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
|
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
|
||||||
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
|
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
|
||||||
std::string source_target_pairs = absl::StrJoin(
|
std::string source_target_pairs = absl::StrJoin(
|
||||||
@ -2017,10 +2094,6 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
|
|||||||
return DefaultAction(reduce);
|
return DefaultAction(reduce);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleAllToAll(HloInstruction*) {
|
|
||||||
return Unimplemented("AllToAll is not implemented on CPU.");
|
|
||||||
}
|
|
||||||
|
|
||||||
Status IrEmitter::HandleSend(HloInstruction* send) {
|
Status IrEmitter::HandleSend(HloInstruction* send) {
|
||||||
// TODO(b/33942983): Support Send/Recv on CPU.
|
// TODO(b/33942983): Support Send/Recv on CPU.
|
||||||
return Unimplemented("Send is not implemented on CPU.");
|
return Unimplemented("Send is not implemented on CPU.");
|
||||||
@ -2749,10 +2822,10 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
|
|||||||
element_alignment);
|
element_alignment);
|
||||||
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
|
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
|
||||||
} else {
|
} else {
|
||||||
auto* memcpy_instruction =
|
auto* memcpy_instruction = b_.CreateMemCpy(
|
||||||
MemCpy(target, /*DstAlign=*/llvm::Align(element_alignment), source,
|
target, /*DstAlign=*/llvm::Align(element_alignment), source,
|
||||||
/*SrcAlign=*/llvm::Align(element_alignment),
|
/*SrcAlign=*/llvm::Align(element_alignment),
|
||||||
element_count * primitive_type_size);
|
element_count * primitive_type_size);
|
||||||
|
|
||||||
// The memcpy does the load and the store internally. The aliasing related
|
// The memcpy does the load and the store internally. The aliasing related
|
||||||
// metadata has to reflect that.
|
// metadata has to reflect that.
|
||||||
|
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -241,6 +241,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
   REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
   REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
+  REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
  REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
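Registering the new symbol here is what ties the two halves of the change together: the IR emitter emits a call to the plain C name __xla_cpu_runtime_AllToAll, and the CPU JIT resolves that name to the address of the runtime function when the compiled computation runs. The snippet below is a hypothetical, much-simplified sketch of what such a name-to-address registry accomplishes; SymbolRegistry and FakeAllToAll are made up for illustration and are not the actual SimpleOrcJIT machinery.

// Minimal sketch of symbol registration and resolution (hypothetical names).
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for __xla_cpu_runtime_AllToAll.
void FakeAllToAll() { std::cout << "runtime all-to-all called\n"; }

class SymbolRegistry {
 public:
  void Register(const std::string& name, void* address) {
    symbols_[name] = address;
  }
  // What the JIT conceptually does when generated code references a symbol.
  void* Resolve(const std::string& name) const {
    auto it = symbols_.find(name);
    return it == symbols_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, void*> symbols_;
};

int main() {
  SymbolRegistry registry;
  registry.Register("__xla_cpu_runtime_AllToAll",
                    reinterpret_cast<void*>(&FakeAllToAll));

  // Generated code only knows the name; the registry supplies the address.
  auto* fn = reinterpret_cast<void (*)()>(
      registry.Resolve("__xla_cpu_runtime_AllToAll"));
  if (fn != nullptr) fn();
}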
@@ -108,7 +108,7 @@ class CollectiveOpsTest : public HloTestBase {
   }
 
   template <typename LiteralType>
-  void TestAllOps() {
+  void TestAllOpsForReduce() {
     auto cast = [&](int value) { return static_cast<LiteralType>(value); };
     auto to_literal = [&](absl::Span<const LiteralType> values) {
       return LiteralUtil::CreateR1<LiteralType>(values);
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
|
||||||
TestAllOps<int8>();
|
TestAllOpsForReduce<int8>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
|
||||||
TestAllOps<uint8>();
|
TestAllOpsForReduce<uint8>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
|
||||||
TestAllOps<uint32>();
|
TestAllOpsForReduce<uint32>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
|
||||||
TestAllOps<int32>();
|
TestAllOpsForReduce<int32>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
|
||||||
TestAllOps<int64>();
|
TestAllOpsForReduce<int64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
|
||||||
TestAllOps<uint64>();
|
TestAllOpsForReduce<uint64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
|
||||||
TestAllOps<float>();
|
TestAllOpsForReduce<float>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
|
||||||
TestAllOps<double>();
|
TestAllOpsForReduce<double>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
|
||||||
TestAllOps<Eigen::half>();
|
TestAllOpsForReduce<Eigen::half>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
|
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
|
||||||
@@ -593,6 +593,98 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
                                      results[3]));
 }
 
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_EmptyReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    c = f32[2] constant({30, 30})
+    d = f32[2] constant({40, 40})
+    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    c_prime = f32[2] get-tuple-element(all2all), index=2
+    d_prime = f32[2] get-tuple-element(all2all), index=3
+    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({10, 10, 20, 20, 30, 30, 40, 40}),
+        results[i], ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_OrderedReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    c = f32[2] constant({30, 30})
+    d = f32[2] constant({40, 40})
+    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={{3,2,1,0}}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    c_prime = f32[2] get-tuple-element(all2all), index=2
+    d_prime = f32[2] get-tuple-element(all2all), index=3
+    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({40, 40, 30, 30, 20, 20, 10, 10}),
+        results[i], ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_TwoReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    all2all = (f32[2], f32[2]) all-to-all(a, b), replica_groups={{2,1},{3,0}}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    ROOT out = f32[4] concatenate(a_prime, b_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({20, 20, 10, 10}), results[i],
+        ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
 XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
   std::string hlo_string = R"(
     HloModule test