[XLA] Use array all to all in the client frontend

PiperOrigin-RevId: 354660949
Change-Id: I201e941172ab949d2632f83efc5cd80f4ebdddca
Blake Hechtman 2021-01-29 20:46:11 -08:00 committed by TensorFlower Gardener
parent 4a0023ba49
commit 88b90a0a37
11 changed files with 342 additions and 25 deletions


@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -2847,6 +2848,72 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout) {
// An array all_to_all may need to violate the layout constraint to be legal,
// so use the tuple version.
if (layout.has_value()) {
return AllToAllTuple(operand, split_dimension, concat_dimension,
split_count, replica_groups, layout);
}
return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
replica_groups);
}
XlaOp XlaBuilder::AllToAllArray(
XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(
const Shape all_to_all_shape,
ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
concat_dimension, split_count));
HloInstructionProto instr;
*instr.mutable_shape() = operand_shape->ToProto();
if (replica_groups.empty()) {
auto* group = instr.add_replica_groups();
for (int64 i = 0; i < split_count; ++i) {
group->add_replica_ids(i);
}
} else {
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
}
instr.add_dimensions(split_dimension);
TF_ASSIGN_OR_RETURN(
XlaOp all_to_all,
AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
if (split_dimension == concat_dimension) {
return all_to_all;
}
DimensionVector sizes;
for (int64 i = 0; i < operand_shape->rank(); ++i) {
if (i != split_dimension) {
sizes.push_back(operand_shape->dimensions(i));
continue;
}
sizes.push_back(split_count);
sizes.push_back(operand_shape->dimensions(i) / split_count);
}
all_to_all = Reshape(all_to_all, sizes);
std::vector<int64> permutation;
for (int64 i = 0; i < operand_shape->rank(); ++i) {
int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
if (i == concat_dimension) {
permutation.push_back(split_dimension);
}
permutation.push_back(dim_after_reshape);
}
all_to_all = Transpose(all_to_all, permutation);
return Reshape(all_to_all_shape, all_to_all);
});
}
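For illustration, here is a minimal standalone sketch of the shape bookkeeping that AllToAllArray above performs when split_dimension and concat_dimension differ. It uses plain std::vector instead of XLA types, and the operand shape and parameters are assumptions chosen for the example, not values taken from this change.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed example: operand f32[4,16,8], split_dimension=1, concat_dimension=2,
  // split_count=4. The all-to-all itself keeps the operand shape.
  const std::vector<int64_t> operand_dims = {4, 16, 8};
  const int64_t split_dimension = 1, concat_dimension = 2, split_count = 4;

  // First reshape: split the split dimension into (split_count, size / split_count).
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    if (i != split_dimension) {
      sizes.push_back(operand_dims[i]);
      continue;
    }
    sizes.push_back(split_count);
    sizes.push_back(operand_dims[i] / split_count);
  }

  // Transpose: move the split factor next to the concat dimension.
  std::vector<int64_t> permutation;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    const int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
    if (i == concat_dimension) permutation.push_back(split_dimension);
    permutation.push_back(dim_after_reshape);
  }

  // Prints "4 4 4 8" and "0 2 1 3": reshape f32[4,16,8] -> f32[4,4,4,8],
  // transpose with {0,2,1,3}, and a final reshape folds the split factor into
  // the concat dimension, giving the inferred output f32[4,4,32].
  for (int64_t s : sizes) std::cout << s << ' ';
  std::cout << '\n';
  for (int64_t p : permutation) std::cout << p << ' ';
  std::cout << '\n';
}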
XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -4581,6 +4648,15 @@ XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
split_count, replica_groups, layout);
}
XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout) {
return operand.builder()->AllToAllTuple(operand, split_dimension,
concat_dimension, split_count,
replica_groups, layout);
}
XlaOp CollectivePermute(
const XlaOp operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs) {


@@ -745,6 +745,11 @@ class XlaBuilder {
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout = absl::nullopt);
XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout);
XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1297,6 +1302,10 @@ class XlaBuilder {
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout);
friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<Layout>& layout);
friend XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1425,6 +1434,10 @@ class XlaBuilder {
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
// Creates an op with the given opcode and the output shape.
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands);
@@ -2203,6 +2216,11 @@ XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
const std::vector<ReplicaGroup>& replica_groups = {},
const absl::optional<Layout>& layout = absl::nullopt);
XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {},
const absl::optional<Layout>& layout = absl::nullopt);
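A minimal client-side sketch of the two entry points declared above; the helper name, shape, and parameters are illustrative assumptions, not part of this change:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildAllToAllExample() {  // hypothetical helper for illustration
  xla::XlaBuilder b("all_to_all_example");
  xla::XlaOp x =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8, 8}), "x");
  // No layout constraint is supplied, so XlaBuilder::AllToAll takes the new
  // array lowering.
  xla::AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/1,
                /*split_count=*/4);
  // Explicitly requests the tuple-based lowering, which is also used whenever
  // a layout is supplied.
  xla::AllToAllTuple(x, /*split_dimension=*/0, /*concat_dimension=*/1,
                     /*split_count=*/4);
}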
// Enqueues a collective operation that sends and receives data across replicas.
//
// - `source_target_pair`: a list of (source_replica_id, target_replica_id)


@@ -414,12 +414,27 @@ TEST_F(XlaBuilderTest, AllToAll) {
auto root = module->entry_computation()->root_instruction();
// AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
EXPECT_EQ(root->opcode(), HloOpcode::kReshape);
EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
HloOpcode::kAllToAll);
EXPECT_TRUE(
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
TEST_F(XlaBuilderTest, AllToAllSpecial) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");
AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
/*split_count=*/2);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
// AllToAll is converted into a single all-to-all HloInstruction.
EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
EXPECT_TRUE(
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8})));
}
TEST_F(XlaBuilderTest, CollectivePermute) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");


@@ -2649,6 +2649,27 @@ tf_cc_test(
],
)
cc_library(
name = "all_to_all_decomposer",
srcs = ["all_to_all_decomposer.cc"],
hdrs = ["all_to_all_decomposer.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_pass",
":op_expander_pass",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "all_gather_decomposer",
srcs = ["all_gather_decomposer.cc"],


@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
bool AllToAllDecomposer::InstructionMatchesPattern(
HloInstruction* instruction) {
auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
if (all_to_all == nullptr) {
return false;
}
// Do not attempt to change layout constrained collectives.
if (all_to_all->constrain_layout()) {
return false;
}
if (all_to_all->shape().IsTuple()) {
return false;
}
if (decompose_to_tuple_) {
return true;
}
return all_to_all->shape().rank() < min_array_rank_;
}
StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
HloInstruction* instruction) {
auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
int64 split_dim = *all_to_all->split_dimension();
int64 all_to_all_group_size =
all_to_all->replica_groups().empty()
? instruction->parent()->parent()->config().replica_count()
: all_to_all->replica_groups()[0].replica_ids_size();
int64 split_size =
all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
if (!decompose_to_tuple_) {
Shape new_all_to_all_shape;
new_all_to_all_shape.set_element_type(
instruction->operand(0)->shape().element_type());
for (int64 i = 0; i < instruction->shape().rank(); ++i) {
if (i != split_dim) {
new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
continue;
}
new_all_to_all_shape.add_dimensions(all_to_all_group_size);
new_all_to_all_shape.add_dimensions(split_size);
for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
new_all_to_all_shape.add_dimensions(1);
}
}
*(new_all_to_all_shape.mutable_layout()) =
LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
HloInstruction* operand_reshape =
instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
new_all_to_all_shape, instruction->mutable_operand(0)));
instruction->SetupDerivedInstruction(operand_reshape);
HloInstruction* all_to_all =
instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
new_all_to_all_shape, {operand_reshape}));
HloInstruction* output_reshape = instruction->parent()->AddInstruction(
HloInstruction::CreateReshape(instruction->shape(), all_to_all));
instruction->SetupDerivedInstruction(output_reshape);
return output_reshape;
}
DimensionVector slice_starts(all_to_all->shape().rank(), 0);
DimensionVector slice_strides(all_to_all->shape().rank(), 1);
DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
all_to_all->shape().dimensions().end());
slice_limits[split_dim] = split_size;
Shape slice_shape = all_to_all->shape();
slice_shape.set_dimensions(split_dim, split_size);
std::vector<HloInstruction*> slices;
slices.reserve(all_to_all_group_size);
HloInstruction* operand = all_to_all->mutable_operand(0);
for (int64 i = 0; i < all_to_all_group_size; ++i) {
slices.push_back(
all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
slice_shape, operand, slice_starts, slice_limits, slice_strides)));
all_to_all->SetupDerivedInstruction(slices.back());
slice_starts[split_dim] = slice_limits[split_dim];
slice_limits[split_dim] += split_size;
}
Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
std::vector<Shape>(all_to_all_group_size, slice_shape));
HloInstruction* new_all_to_all =
all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
all_to_all_shape, slices, all_to_all->replica_groups(), false,
all_to_all->channel_id(), absl::nullopt));
std::vector<HloInstruction*> gtes;
gtes.reserve(all_to_all_group_size);
for (int64 i = 0; i < all_to_all_group_size; ++i) {
gtes.push_back(all_to_all->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
all_to_all->SetupDerivedInstruction(new_all_to_all);
}
HloInstruction* concat = all_to_all->parent()->AddInstruction(
HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
all_to_all->SetupDerivedInstruction(concat);
return concat;
}
} // namespace xla
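Tracing the expansion above with illustrative values (four participating replicas, split dimension 0, an f32[8,8] operand): the tuple branch produces four slices of shape f32[2,8], a tuple-shaped all-to-all over (f32[2,8], f32[2,8], f32[2,8], f32[2,8]), four get-tuple-elements, and a concatenate on the split dimension back to f32[8,8]. The array branch (decompose_to_tuple_ == false) instead reshapes the operand so the split dimension becomes (group size, split size), padded with size-1 dimensions up to min_array_rank_; for the same operand with min_array_rank_ = 4 that gives f32[4,2,1,8], with a reshape back to f32[8,8] after the all-to-all.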


@@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
// AllToAllDecomposer is a pass that converts unsupported array all_to_all ops
// either into tuple all_to_all ops or into array all_to_all ops of at least a
// minimum rank, in which the split dimension has the size of the replica group.
class AllToAllDecomposer : public OpExpanderPass {
public:
explicit AllToAllDecomposer(bool decompose_to_tuple = true,
int64 min_array_rank = 0)
: decompose_to_tuple_(decompose_to_tuple),
min_array_rank_(min_array_rank) {}
~AllToAllDecomposer() override = default;
absl::string_view name() const override { return "all_to_all_decomposer"; }
private:
bool InstructionMatchesPattern(HloInstruction* instruction) override;
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
bool decompose_to_tuple_;
int64 min_array_rank_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
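A minimal sketch of how the pass might be registered when the array form is preferred; the helper name, pipeline name, and rank threshold are assumptions for illustration:

#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

xla::Status DecomposeSmallAllToAlls(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("all-to-all-decomposition");
  // Keep the array form, but rewrite any all-to-all of rank < 3 so that the
  // split dimension becomes a standalone, group-sized dimension (padding with
  // size-1 dimensions up to rank 3).
  pipeline.AddPass<xla::AllToAllDecomposer>(/*decompose_to_tuple=*/false,
                                            /*min_array_rank=*/3);
  return pipeline.Run(module).status();
}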


@@ -157,6 +157,7 @@ cc_library(
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:VectorOps",
"//tensorflow/compiler/xla/service:all_gather_decomposer",
"//tensorflow/compiler/xla/service:all_to_all_decomposer",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:topk_rewriter",


@@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -292,6 +293,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<QrExpander>();
pipeline.AddPass<TriangularSolveExpander>();
pipeline.AddPass<AllGatherDecomposer>();
pipeline.AddPass<AllToAllDecomposer>();
// Inline computations with a single call site.
pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
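Registered with no arguments as above, the pass uses the defaults declared in all_to_all_decomposer.h, so the registration is equivalent to the line below; presumably this keeps the CPU (and, later in this commit, GPU) pipelines seeing the tuple all-to-all form they already lower:

  pipeline.AddPass<AllToAllDecomposer>(/*decompose_to_tuple=*/true,
                                       /*min_array_rank=*/0);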


@@ -1427,6 +1427,7 @@ cc_library(
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:all_gather_decomposer",
"//tensorflow/compiler/xla/service:all_reduce_combiner",
"//tensorflow/compiler/xla/service:all_to_all_decomposer",
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:call_inliner",


@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
@@ -157,6 +158,7 @@ Status GpuCompiler::OptimizeHloModule(
[](const HloAllGatherInstruction& ag) {
return !NcclAllGatherThunk::CanImplement(&ag);
});
pipeline.AddPass<AllToAllDecomposer>();
pipeline.AddPass<OperandUpcaster>();
@@ -804,30 +806,30 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
llvm_modules.size());
tensorflow::BlockingCounter counter(llvm_modules.size());
for (int i = 0; i < llvm_modules.size(); i++) {
thread_pool->Schedule([&compile_results, compile_single_module, i,
&llvm_modules, &counter] {
llvm::Module* original_module = llvm_modules[i].get();
llvm::LLVMContext context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
thread_pool->Schedule(
[&compile_results, compile_single_module, i, &llvm_modules, &counter] {
llvm::Module* original_module = llvm_modules[i].get();
llvm::LLVMContext context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
std::unique_ptr<llvm::Module> new_llvm_module;
// Switch to a new context by dumping and re-parsing LLVM IR. Each thread
// has its own context to avoid race conditions.
{
std::string ir;
{
llvm::raw_string_ostream os(ir);
original_module->print(os, nullptr);
}
llvm::SMDiagnostic err;
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
}
std::unique_ptr<llvm::Module> new_llvm_module;
// Switch to a new context by dumping and re-parsing LLVM IR. Each
// thread has its own context to avoid race conditions.
{
std::string ir;
{
llvm::raw_string_ostream os(ir);
original_module->print(os, nullptr);
}
llvm::SMDiagnostic err;
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
}
compile_results[i] = compile_single_module(
new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
counter.DecrementCount();
});
compile_results[i] = compile_single_module(
new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
counter.DecrementCount();
});
}
counter.Wait();


@@ -229,7 +229,8 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
}
int64 replica_count = hlo->GetModule()->config().replica_count();
if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
if (replica_count != 1 && !replicas_seen.empty() &&
replicas_seen.size() != replica_count) {
return InternalError(
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
"contains %d replicas: %s",