[XLA] Use array all to all in the client frontend
PiperOrigin-RevId: 354660949
Change-Id: I201e941172ab949d2632f83efc5cd80f4ebdddca

parent 4a0023ba49 · commit 88b90a0a37
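In short: the client-side AllToAll now emits a single array-shaped all-to-all HLO unless the caller constrains the layout, in which case it falls back to the tuple-based decomposition. A minimal sketch of the two call paths (the builder name, parameter shape, and dimension choices below are illustrative, not from this commit):

// Hypothetical usage sketch; shapes and dimensions are made up.
XlaBuilder b("all_to_all_demo");
XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");

// No layout given: builds one array-shaped kAllToAll instruction
// (plus reshape/transpose fix-ups when split != concat dimension).
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/2, /*split_count=*/2);

// Layout given: routed to AllToAllTuple, i.e. the old
// slice -> all-to-all -> get-tuple-element -> concatenate form.
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/2, /*split_count=*/2,
         /*replica_groups=*/{}, LayoutUtil::MakeLayout({2, 1, 0}));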
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -2847,6 +2848,72 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                            int64 concat_dimension, int64 split_count,
                            const std::vector<ReplicaGroup>& replica_groups,
                            const absl::optional<Layout>& layout) {
+  // Array all_to_all may need to violate layout constraint to be legal so use
+  // the tuple version.
+  if (layout.has_value()) {
+    return AllToAllTuple(operand, split_dimension, concat_dimension,
+                         split_count, replica_groups, layout);
+  }
+  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
+                       replica_groups);
+}
+
+XlaOp XlaBuilder::AllToAllArray(
+    XlaOp operand, int64 split_dimension, int64 concat_dimension,
+    int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+    TF_ASSIGN_OR_RETURN(
+        const Shape all_to_all_shape,
+        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
+                                           concat_dimension, split_count));
+    HloInstructionProto instr;
+    *instr.mutable_shape() = operand_shape->ToProto();
+    if (replica_groups.empty()) {
+      auto* group = instr.add_replica_groups();
+      for (int64 i = 0; i < split_count; ++i) {
+        group->add_replica_ids(i);
+      }
+    } else {
+      for (const ReplicaGroup& group : replica_groups) {
+        *instr.add_replica_groups() = group;
+      }
+    }
+    instr.add_dimensions(split_dimension);
+    TF_ASSIGN_OR_RETURN(
+        XlaOp all_to_all,
+        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
+    if (split_dimension == concat_dimension) {
+      return all_to_all;
+    }
+    DimensionVector sizes;
+    for (int64 i = 0; i < operand_shape->rank(); ++i) {
+      if (i != split_dimension) {
+        sizes.push_back(operand_shape->dimensions(i));
+        continue;
+      }
+      sizes.push_back(split_count);
+      sizes.push_back(operand_shape->dimensions(i) / split_count);
+    }
+    all_to_all = Reshape(all_to_all, sizes);
+
+    std::vector<int64> permutation;
+    for (int64 i = 0; i < operand_shape->rank(); ++i) {
+      int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
+      if (i == concat_dimension) {
+        permutation.push_back(split_dimension);
+      }
+      permutation.push_back(dim_after_reshape);
+    }
+    all_to_all = Transpose(all_to_all, permutation);
+    return Reshape(all_to_all_shape, all_to_all);
+  });
+}
+
+XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
+                                int64 concat_dimension, int64 split_count,
+                                const std::vector<ReplicaGroup>& replica_groups,
+                                const absl::optional<Layout>& layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
 
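To make the reshape/transpose bookkeeping in AllToAllArray above concrete, here is a worked trace under assumed shapes (the shapes and dimension choices are illustrative, not taken from the commit):

// Assume operand = f32[4, 16, 8], split_dimension = 1,
// concat_dimension = 2, split_count = 2.
//
// InferAllToAllShape divides the split dim and multiplies the concat dim
// by split_count: all_to_all_shape = f32[4, 8, 16].
// The kAllToAll instruction itself keeps the operand shape f32[4, 16, 8];
// the fix-ups then rearrange it:
//   sizes       = {4, 2, 8, 8}   // split dim 16 becomes {2, 8}
//   Reshape   -> f32[4, 2, 8, 8]
//   permutation = {0, 2, 1, 3}   // split-count dim moves before concat dim
//   Transpose -> f32[4, 8, 2, 8]
//   Reshape   -> f32[4, 8, 16]   // {2, 8} merges into the concat dim
//
// When split_dimension == concat_dimension, the early return skips all of
// this, which is what the new AllToAllSpecial test below checks.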
@@ -4581,6 +4648,15 @@ XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
                  split_count, replica_groups, layout);
 }
 
+XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
+                    int64 concat_dimension, int64 split_count,
+                    const std::vector<ReplicaGroup>& replica_groups,
+                    const absl::optional<Layout>& layout) {
+  return operand.builder()->AllToAllTuple(operand, split_dimension,
+                                          concat_dimension, split_count,
+                                          replica_groups, layout);
+}
+
 XlaOp CollectivePermute(
     const XlaOp operand,
     const std::vector<std::pair<int64, int64>>& source_target_pairs) {
@@ -745,6 +745,11 @@ class XlaBuilder {
                  const std::vector<ReplicaGroup>& replica_groups,
                  const absl::optional<Layout>& layout = absl::nullopt);
+
+  XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                      int64 concat_dimension, int64 split_count,
+                      const std::vector<ReplicaGroup>& replica_groups,
+                      const absl::optional<Layout>& layout);
 
   XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1297,6 +1302,10 @@ class XlaBuilder {
                          int64 concat_dimension, int64 split_count,
                          const std::vector<ReplicaGroup>& replica_groups,
                          const absl::optional<Layout>& layout);
+  friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                             int64 concat_dimension, int64 split_count,
+                             const std::vector<ReplicaGroup>& replica_groups,
+                             const absl::optional<Layout>& layout);
   friend XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1425,6 +1434,10 @@ class XlaBuilder {
       absl::Span<const XlaComputation* const> branch_computations,
       absl::Span<const XlaOp> branch_operands);
 
+  XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
+                      int64 concat_dimension, int64 split_count,
+                      const std::vector<ReplicaGroup>& replica_groups);
+
   // Creates an op with the given opcode and the output shape.
   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                          absl::Span<const XlaOp> operands);
@@ -2203,6 +2216,11 @@ XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                const std::vector<ReplicaGroup>& replica_groups = {},
                const absl::optional<Layout>& layout = absl::nullopt);
 
+XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
+                    int64 concat_dimension, int64 split_count,
+                    const std::vector<ReplicaGroup>& replica_groups = {},
+                    const absl::optional<Layout>& layout = absl::nullopt);
+
 // Enqueues an collective operation that sends and receives data cross replicas.
 //
 // - `source_target_pair`: a list of (source_replica_id, target_replica_id)
@@ -414,12 +414,27 @@ TEST_F(XlaBuilderTest, AllToAll) {
   auto root = module->entry_computation()->root_instruction();
 
   // AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
-  EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
-  EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
+  EXPECT_EQ(root->opcode(), HloOpcode::kReshape);
+  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
+            HloOpcode::kAllToAll);
   EXPECT_TRUE(
       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
 }
 
+TEST_F(XlaBuilderTest, AllToAllSpecial) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");
+  AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
+           /*split_count=*/2);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  // AllToAll is converted into a single all-to-all HloInstruction.
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8})));
+}
+
 TEST_F(XlaBuilderTest, CollectivePermute) {
   XlaBuilder b(TestName());
   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
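The updated expectations walk the operand chain produced by AllToAllArray, read from the root down (a sketch inferred from the builder code above, not captured test output):

// root                                       kReshape   (to inferred shape)
// root->operand(0)                           kTranspose
// root->operand(0)->operand(0)               kReshape   (splits the split dim)
// root->operand(0)->operand(0)->operand(0)   kAllToAll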
@@ -2649,6 +2649,27 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "all_to_all_decomposer",
+    srcs = ["all_to_all_decomposer.cc"],
+    hdrs = ["all_to_all_decomposer.h"],
+    deps = [
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_pass",
+        ":op_expander_pass",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "all_gather_decomposer",
     srcs = ["all_gather_decomposer.cc"],
tensorflow/compiler/xla/service/all_to_all_decomposer.cc (new file, 131 lines)
@@ -0,0 +1,131 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+bool AllToAllDecomposer::InstructionMatchesPattern(
+    HloInstruction* instruction) {
+  auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
+  if (all_to_all == nullptr) {
+    return false;
+  }
+  // Do not attempt to change layout constrained collectives.
+  if (all_to_all->constrain_layout()) {
+    return false;
+  }
+  if (all_to_all->shape().IsTuple()) {
+    return false;
+  }
+  if (decompose_to_tuple_) {
+    return true;
+  }
+  return all_to_all->shape().rank() < min_array_rank_;
+}
+StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
+    HloInstruction* instruction) {
+  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
+  int64 split_dim = *all_to_all->split_dimension();
+  int64 all_to_all_group_size =
+      all_to_all->replica_groups().empty()
+          ? instruction->parent()->parent()->config().replica_count()
+          : all_to_all->replica_groups()[0].replica_ids_size();
+  int64 split_size =
+      all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
+  if (!decompose_to_tuple_) {
+    Shape new_all_to_all_shape;
+    new_all_to_all_shape.set_element_type(
+        instruction->operand(0)->shape().element_type());
+    for (int64 i = 0; i < instruction->shape().rank(); ++i) {
+      if (i != split_dim) {
+        new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
+        continue;
+      }
+      new_all_to_all_shape.add_dimensions(all_to_all_group_size);
+      new_all_to_all_shape.add_dimensions(split_size);
+      for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
+        new_all_to_all_shape.add_dimensions(1);
+      }
+    }
+    *(new_all_to_all_shape.mutable_layout()) =
+        LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
+    HloInstruction* operand_reshape =
+        instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
+            new_all_to_all_shape, instruction->mutable_operand(0)));
+    instruction->SetupDerivedInstruction(operand_reshape);
+    HloInstruction* all_to_all =
+        instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
+            new_all_to_all_shape, {operand_reshape}));
+    HloInstruction* output_reshape = instruction->parent()->AddInstruction(
+        HloInstruction::CreateReshape(instruction->shape(), all_to_all));
+    instruction->SetupDerivedInstruction(output_reshape);
+    return output_reshape;
+  }
+  DimensionVector slice_starts(all_to_all->shape().rank(), 0);
+  DimensionVector slice_strides(all_to_all->shape().rank(), 1);
+  DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
+                               all_to_all->shape().dimensions().end());
+  slice_limits[split_dim] = split_size;
+  Shape slice_shape = all_to_all->shape();
+  slice_shape.set_dimensions(split_dim, split_size);
+  std::vector<HloInstruction*> slices;
+  slices.reserve(all_to_all_group_size);
+  HloInstruction* operand = all_to_all->mutable_operand(0);
+  for (int64 i = 0; i < all_to_all_group_size; ++i) {
+    slices.push_back(
+        all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
+            slice_shape, operand, slice_starts, slice_limits, slice_strides)));
+    all_to_all->SetupDerivedInstruction(slices.back());
+    slice_starts[split_dim] = slice_limits[split_dim];
+    slice_limits[split_dim] += split_size;
+  }
+  Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
+      std::vector<Shape>(all_to_all_group_size, slice_shape));
+  HloInstruction* new_all_to_all =
+      all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
+          all_to_all_shape, slices, all_to_all->replica_groups(), false,
+          all_to_all->channel_id(), absl::nullopt));
+  std::vector<HloInstruction*> gtes;
+  gtes.reserve(all_to_all_group_size);
+  for (int64 i = 0; i < all_to_all_group_size; ++i) {
+    gtes.push_back(all_to_all->parent()->AddInstruction(
+        HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
+    all_to_all->SetupDerivedInstruction(new_all_to_all);
+  }
+  HloInstruction* concat = all_to_all->parent()->AddInstruction(
+      HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
+  all_to_all->SetupDerivedInstruction(concat);
+  return concat;
+}
+
+}  // namespace xla
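For reference, the tuple-mode expansion above follows the "slices -> all-to-all -> gte -> concat" shape asserted in the builder test. A hand-written HLO sketch for an f32[8,8] operand split on dimension 0 across two participants (illustrative only, not tool output):

  %s0  = f32[4,8] slice(%p), slice={[0:4], [0:8]}
  %s1  = f32[4,8] slice(%p), slice={[4:8], [0:8]}
  %a2a = (f32[4,8], f32[4,8]) all-to-all(%s0, %s1)
  %g0  = f32[4,8] get-tuple-element(%a2a), index=0
  %g1  = f32[4,8] get-tuple-element(%a2a), index=1
  %c   = f32[8,8] concatenate(%g0, %g1), dimensions={0}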
tensorflow/compiler/xla/service/all_to_all_decomposer.h (new file, 49 lines)
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+
+namespace xla {
+
+// AllToAllDecomposer is a pass which converts unsupported array all_to_all
+// into tuple all_to_all or array all_to_all with a minimum rank where the
+// split dimension is the size of the replica_groups.
+class AllToAllDecomposer : public OpExpanderPass {
+ public:
+  explicit AllToAllDecomposer(bool decompose_to_tuple = true,
+                              int64 min_array_rank = 0)
+      : decompose_to_tuple_(decompose_to_tuple),
+        min_array_rank_(min_array_rank) {}
+  ~AllToAllDecomposer() override = default;
+  absl::string_view name() const override { return "all_to_all_decomposer"; }
+
+ private:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+  StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+  bool decompose_to_tuple_;
+  int64 min_array_rank_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
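A sketch of how the pass might be configured in a pipeline (the pipeline name is hypothetical; the CPU and GPU pipelines below use the default constructor):

  HloPassPipeline pipeline("example");

  // Default: rewrite array all-to-all into the tuple form.
  pipeline.AddPass<AllToAllDecomposer>();

  // Alternative: keep the array form but reshape it so the split-count is
  // its own dimension, padding with size-1 dims up to min_array_rank.
  pipeline.AddPass<AllToAllDecomposer>(/*decompose_to_tuple=*/false,
                                       /*min_array_rank=*/3);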
@@ -157,6 +157,7 @@ cc_library(
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:VectorOps",
         "//tensorflow/compiler/xla/service:all_gather_decomposer",
+        "//tensorflow/compiler/xla/service:all_to_all_decomposer",
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:dump",
         "//tensorflow/compiler/xla/service:topk_rewriter",
@@ -57,6 +57,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
+#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -292,6 +293,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<QrExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
   pipeline.AddPass<AllGatherDecomposer>();
+  pipeline.AddPass<AllToAllDecomposer>();
 
   // Inline computations with a single call site.
   pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
@@ -1427,6 +1427,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
         "//tensorflow/compiler/xla/service:all_gather_decomposer",
         "//tensorflow/compiler/xla/service:all_reduce_combiner",
+        "//tensorflow/compiler/xla/service:all_to_all_decomposer",
         "//tensorflow/compiler/xla/service:batchnorm_expander",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:call_inliner",
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
+#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
@@ -157,6 +158,7 @@ Status GpuCompiler::OptimizeHloModule(
         [](const HloAllGatherInstruction& ag) {
           return !NcclAllGatherThunk::CanImplement(&ag);
         });
+    pipeline.AddPass<AllToAllDecomposer>();
 
     pipeline.AddPass<OperandUpcaster>();
 
@@ -804,30 +806,30 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
                                      llvm_modules.size());
   tensorflow::BlockingCounter counter(llvm_modules.size());
   for (int i = 0; i < llvm_modules.size(); i++) {
-    thread_pool->Schedule([&compile_results, compile_single_module, i,
-                           &llvm_modules, &counter] {
+    thread_pool->Schedule(
+        [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
       llvm::Module* original_module = llvm_modules[i].get();
       llvm::LLVMContext context;
       std::string buffer;
       llvm::raw_string_ostream error(buffer);
 
       std::unique_ptr<llvm::Module> new_llvm_module;
-      // Switch to a new context by dumping and re-parsing LLVM IR. Each thread
-      // has its own context to avoid race conditions.
+      // Switch to a new context by dumping and re-parsing LLVM IR. Each
+      // thread has its own context to avoid race conditions.
       {
         std::string ir;
         {
           llvm::raw_string_ostream os(ir);
           original_module->print(os, nullptr);
         }
         llvm::SMDiagnostic err;
         new_llvm_module = llvm::parseAssemblyString(ir, err, context);
       }
 
       compile_results[i] = compile_single_module(
           new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
       counter.DecrementCount();
     });
   }
   counter.Wait();
 
@@ -229,7 +229,8 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
   }
 
   int64 replica_count = hlo->GetModule()->config().replica_count();
-  if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
+  if (replica_count != 1 && !replicas_seen.empty() &&
+      replicas_seen.size() != replica_count) {
     return InternalError(
         "Replica count in HloModuleConfig is %d, but ReplicaGroup config "
         "contains %d replicas: %s",