[XLA] Use array all to all in the client frontend

PiperOrigin-RevId: 354660949
Change-Id: I201e941172ab949d2632f83efc5cd80f4ebdddca

parent 4a0023ba49
commit 88b90a0a37
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -2847,6 +2848,72 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension,
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups,
                           const absl::optional<Layout>& layout) {
  // Array all_to_all may need to violate layout constraint to be legal so use
  // the tuple version.
  if (layout.has_value()) {
    return AllToAllTuple(operand, split_dimension, concat_dimension,
                         split_count, replica_groups, layout);
  }
  return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
                       replica_groups);
}

XlaOp XlaBuilder::AllToAllArray(
    XlaOp operand, int64 split_dimension, int64 concat_dimension,
    int64 split_count, const std::vector<ReplicaGroup>& replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64 i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    DimensionVector sizes;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    std::vector<int64> permutation;
    for (int64 i = 0; i < operand_shape->rank(); ++i) {
      int64 dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    return Reshape(all_to_all_shape, all_to_all);
  });
}

XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64 split_dimension,
                                int64 concat_dimension, int64 split_count,
                                const std::vector<ReplicaGroup>& replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
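To make the lowering above concrete, here is a standalone trace of the same reshape/transpose bookkeeping for a hypothetical f32[4,16] operand with split_dimension=1, concat_dimension=0, split_count=2. The shapes and values are illustrative assumptions; the snippet is plain C++ and does not call into XLA.

// Standalone illustration of the reshape/permutation bookkeeping used by
// AllToAllArray above. The operand shape and parameters are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {4, 16};  // hypothetical operand shape
  const int64_t split_dimension = 1, concat_dimension = 0, split_count = 2;

  // Step 1: split the split dimension into (split_count, dims[split] / split_count).
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i != split_dimension) {
      sizes.push_back(dims[i]);
      continue;
    }
    sizes.push_back(split_count);
    sizes.push_back(dims[i] / split_count);
  }
  // sizes == {4, 2, 8}

  // Step 2: move the split_count dimension so it lands just before the
  // concat dimension.
  std::vector<int64_t> permutation;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
    if (i == concat_dimension) {
      permutation.push_back(split_dimension);
    }
    permutation.push_back(dim_after_reshape);
  }
  // permutation == {1, 0, 2}, so the transposed shape is {2, 4, 8}.

  // Step 3: the final reshape collapses to the inferred all-to-all shape
  // {dims[concat] * split_count, dims[split] / split_count} == {8, 8}.
  for (int64_t p : permutation) std::cout << p << " ";
  std::cout << "\n";
  return 0;
}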
@@ -4581,6 +4648,15 @@ XlaOp AllToAll(const XlaOp operand, int64 split_dimension,
                 split_count, replica_groups, layout);
}

XlaOp AllToAllTuple(const XlaOp operand, int64 split_dimension,
                    int64 concat_dimension, int64 split_count,
                    const std::vector<ReplicaGroup>& replica_groups,
                    const absl::optional<Layout>& layout) {
  return operand.builder()->AllToAllTuple(operand, split_dimension,
                                          concat_dimension, split_count,
                                          replica_groups, layout);
}

XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
@@ -745,6 +745,11 @@ class XlaBuilder {
                  const std::vector<ReplicaGroup>& replica_groups,
                  const absl::optional<Layout>& layout = absl::nullopt);

  XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                      int64 concat_dimension, int64 split_count,
                      const std::vector<ReplicaGroup>& replica_groups,
                      const absl::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1297,6 +1302,10 @@ class XlaBuilder {
                           int64 concat_dimension, int64 split_count,
                           const std::vector<ReplicaGroup>& replica_groups,
                           const absl::optional<Layout>& layout);
  friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                             int64 concat_dimension, int64 split_count,
                             const std::vector<ReplicaGroup>& replica_groups,
                             const absl::optional<Layout>& layout);
  friend XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
@@ -1425,6 +1434,10 @@ class XlaBuilder {
      absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);

  XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
                      int64 concat_dimension, int64 split_count,
                      const std::vector<ReplicaGroup>& replica_groups);

  // Creates an op with the given opcode and the output shape.
  virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                         absl::Span<const XlaOp> operands);
@@ -2203,6 +2216,11 @@ XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               const std::vector<ReplicaGroup>& replica_groups = {},
               const absl::optional<Layout>& layout = absl::nullopt);

XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                    int64 concat_dimension, int64 split_count,
                    const std::vector<ReplicaGroup>& replica_groups = {},
                    const absl::optional<Layout>& layout = absl::nullopt);

// Enqueues a collective operation that sends and receives data across replicas.
//
// - `source_target_pair`: a list of (source_replica_id, target_replica_id)
@@ -414,12 +414,27 @@ TEST_F(XlaBuilderTest, AllToAll) {
  auto root = module->entry_computation()->root_instruction();

  // AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
  EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
  EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll);
  EXPECT_EQ(root->opcode(), HloOpcode::kReshape);
  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
            HloOpcode::kAllToAll);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}

TEST_F(XlaBuilderTest, AllToAllSpecial) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");
  AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
           /*split_count=*/2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();

  // AllToAll is converted into a single all-to-all HloInstruction.
  EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8})));
}

TEST_F(XlaBuilderTest, CollectivePermute) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
@@ -2649,6 +2649,27 @@ tf_cc_test(
    ],
)

cc_library(
    name = "all_to_all_decomposer",
    srcs = ["all_to_all_decomposer.cc"],
    hdrs = ["all_to_all_decomposer.h"],
    deps = [
        ":hlo",
        ":hlo_casting_utils",
        ":hlo_pass",
        ":op_expander_pass",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "all_gather_decomposer",
    srcs = ["all_gather_decomposer.cc"],
tensorflow/compiler/xla/service/all_to_all_decomposer.cc (new file, 131 lines)
@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"

#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
bool AllToAllDecomposer::InstructionMatchesPattern(
    HloInstruction* instruction) {
  auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
  if (all_to_all == nullptr) {
    return false;
  }
  // Do not attempt to change layout constrained collectives.
  if (all_to_all->constrain_layout()) {
    return false;
  }
  if (all_to_all->shape().IsTuple()) {
    return false;
  }
  if (decompose_to_tuple_) {
    return true;
  }
  return all_to_all->shape().rank() < min_array_rank_;
}
StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
    HloInstruction* instruction) {
  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
  int64 split_dim = *all_to_all->split_dimension();
  int64 all_to_all_group_size =
      all_to_all->replica_groups().empty()
          ? instruction->parent()->parent()->config().replica_count()
          : all_to_all->replica_groups()[0].replica_ids_size();
  int64 split_size =
      all_to_all->shape().dimensions(split_dim) / all_to_all_group_size;
  if (!decompose_to_tuple_) {
    Shape new_all_to_all_shape;
    new_all_to_all_shape.set_element_type(
        instruction->operand(0)->shape().element_type());
    for (int64 i = 0; i < instruction->shape().rank(); ++i) {
      if (i != split_dim) {
        new_all_to_all_shape.add_dimensions(all_to_all->shape().dimensions(i));
        continue;
      }
      new_all_to_all_shape.add_dimensions(all_to_all_group_size);
      new_all_to_all_shape.add_dimensions(split_size);
      for (int64 j = all_to_all->shape().rank() + 1; j < min_array_rank_; ++j) {
        new_all_to_all_shape.add_dimensions(1);
      }
    }
    *(new_all_to_all_shape.mutable_layout()) =
        LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
    HloInstruction* operand_reshape =
        instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
            new_all_to_all_shape, instruction->mutable_operand(0)));
    instruction->SetupDerivedInstruction(operand_reshape);
    HloInstruction* all_to_all =
        instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
            new_all_to_all_shape, {operand_reshape}));
    HloInstruction* output_reshape = instruction->parent()->AddInstruction(
        HloInstruction::CreateReshape(instruction->shape(), all_to_all));
    instruction->SetupDerivedInstruction(output_reshape);
    return output_reshape;
  }
  DimensionVector slice_starts(all_to_all->shape().rank(), 0);
  DimensionVector slice_strides(all_to_all->shape().rank(), 1);
  DimensionVector slice_limits(all_to_all->shape().dimensions().begin(),
                               all_to_all->shape().dimensions().end());
  slice_limits[split_dim] = split_size;
  Shape slice_shape = all_to_all->shape();
  slice_shape.set_dimensions(split_dim, split_size);
  std::vector<HloInstruction*> slices;
  slices.reserve(all_to_all_group_size);
  HloInstruction* operand = all_to_all->mutable_operand(0);
  for (int64 i = 0; i < all_to_all_group_size; ++i) {
    slices.push_back(
        all_to_all->parent()->AddInstruction(HloInstruction::CreateSlice(
            slice_shape, operand, slice_starts, slice_limits, slice_strides)));
    all_to_all->SetupDerivedInstruction(slices.back());
    slice_starts[split_dim] = slice_limits[split_dim];
    slice_limits[split_dim] += split_size;
  }
  Shape all_to_all_shape = ShapeUtil::MakeTupleShape(
      std::vector<Shape>(all_to_all_group_size, slice_shape));
  HloInstruction* new_all_to_all =
      all_to_all->parent()->AddInstruction(HloInstruction::CreateAllToAll(
          all_to_all_shape, slices, all_to_all->replica_groups(), false,
          all_to_all->channel_id(), absl::nullopt));
  std::vector<HloInstruction*> gtes;
  gtes.reserve(all_to_all_group_size);
  for (int64 i = 0; i < all_to_all_group_size; ++i) {
    gtes.push_back(all_to_all->parent()->AddInstruction(
        HloInstruction::CreateGetTupleElement(slice_shape, new_all_to_all, i)));
    all_to_all->SetupDerivedInstruction(new_all_to_all);
  }
  HloInstruction* concat = all_to_all->parent()->AddInstruction(
      HloInstruction::CreateConcatenate(all_to_all->shape(), gtes, split_dim));
  all_to_all->SetupDerivedInstruction(concat);
  return concat;
}

}  // namespace xla
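To make the tuple branch above concrete, here is a standalone trace of its slicing arithmetic for a hypothetical f32[8,8] operand with split_dim = 0 and an all-to-all group of size 2. The shapes are assumptions for illustration; the snippet is plain C++ and does not depend on XLA.

// Standalone illustration of the per-participant slice bounds computed by
// the tuple decomposition above. Operand shape and group size are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {8, 8};  // hypothetical operand shape
  const int64_t split_dim = 0;
  const int64_t group_size = 2;                              // replica group size
  const int64_t split_size = dims[split_dim] / group_size;   // 4

  // Each participant contributes one slice of the operand along split_dim;
  // the tuple all-to-all exchanges the slices, and the get-tuple-elements are
  // concatenated back along the same dimension.
  int64_t start = 0;
  int64_t limit = split_size;
  for (int64_t i = 0; i < group_size; ++i) {
    std::cout << "slice " << i << ": [" << start << ", " << limit
              << ") x [0, " << dims[1] << ")\n";  // slice shape {4, 8}
    start = limit;
    limit += split_size;
  }
  // Prints: slice 0: [0, 4) x [0, 8) and slice 1: [4, 8) x [0, 8).
  return 0;
}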
tensorflow/compiler/xla/service/all_to_all_decomposer.h (new file, 49 lines)
@@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"

namespace xla {

// AllToAllDecomposer is a pass which converts unsupported array all_to_all
// into tuple all_to_all or array all_to_all with a minimum rank where the split
// dimension is the size of the replica_groups.
class AllToAllDecomposer : public OpExpanderPass {
 public:
  explicit AllToAllDecomposer(bool decompose_to_tuple = true,
                              int64 min_array_rank = 0)
      : decompose_to_tuple_(decompose_to_tuple),
        min_array_rank_(min_array_rank) {}
  ~AllToAllDecomposer() override = default;
  absl::string_view name() const override { return "all_to_all_decomposer"; }

 private:
  bool InstructionMatchesPattern(HloInstruction* instruction) override;
  StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override;
  bool decompose_to_tuple_;
  int64 min_array_rank_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_
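For orientation, a minimal sketch of how the new pass can be registered in an HLO pass pipeline, in the same way the CpuCompiler and GpuCompiler hunks below do. The pipeline name and the non-default arguments are illustrative assumptions rather than part of this change, and the fragment presumes the usual XLA pass-pipeline headers.

// Illustrative only: register the decomposer in an HLO optimization pipeline.
// The CPU/GPU compiler changes below rely on the default constructor arguments.
HloPassPipeline pipeline("collective-lowering");  // hypothetical pipeline name
// Default behavior: rewrite array all-to-all into the tuple form.
pipeline.AddPass<AllToAllDecomposer>();
// Assumed alternative usage: keep the array form but expand it to at least the
// given rank, with the split dimension sized to the replica group.
pipeline.AddPass<AllToAllDecomposer>(/*decompose_to_tuple=*/false,
                                     /*min_array_rank=*/3);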
@@ -157,6 +157,7 @@ cc_library(
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:VectorOps",
        "//tensorflow/compiler/xla/service:all_gather_decomposer",
        "//tensorflow/compiler/xla/service:all_to_all_decomposer",
        "//tensorflow/compiler/xla/service:copy_insertion",
        "//tensorflow/compiler/xla/service:dump",
        "//tensorflow/compiler/xla/service:topk_rewriter",
@@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@@ -292,6 +293,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
@@ -1427,6 +1427,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:algebraic_simplifier",
        "//tensorflow/compiler/xla/service:all_gather_decomposer",
        "//tensorflow/compiler/xla/service:all_reduce_combiner",
        "//tensorflow/compiler/xla/service:all_to_all_decomposer",
        "//tensorflow/compiler/xla/service:batchnorm_expander",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:call_inliner",
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
@@ -157,6 +158,7 @@ Status GpuCompiler::OptimizeHloModule(
        [](const HloAllGatherInstruction& ag) {
          return !NcclAllGatherThunk::CanImplement(&ag);
        });
    pipeline.AddPass<AllToAllDecomposer>();

    pipeline.AddPass<OperandUpcaster>();

@@ -804,30 +806,30 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
                       llvm_modules.size());
  tensorflow::BlockingCounter counter(llvm_modules.size());
  for (int i = 0; i < llvm_modules.size(); i++) {
    thread_pool->Schedule([&compile_results, compile_single_module, i,
                           &llvm_modules, &counter] {
      llvm::Module* original_module = llvm_modules[i].get();
      llvm::LLVMContext context;
      std::string buffer;
      llvm::raw_string_ostream error(buffer);
    thread_pool->Schedule(
        [&compile_results, compile_single_module, i, &llvm_modules, &counter] {
          llvm::Module* original_module = llvm_modules[i].get();
          llvm::LLVMContext context;
          std::string buffer;
          llvm::raw_string_ostream error(buffer);

      std::unique_ptr<llvm::Module> new_llvm_module;
      // Switch to a new context by dumping and re-parsing LLVM IR. Each thread
      // has its own context to avoid race conditions.
      {
        std::string ir;
        {
          llvm::raw_string_ostream os(ir);
          original_module->print(os, nullptr);
        }
        llvm::SMDiagnostic err;
        new_llvm_module = llvm::parseAssemblyString(ir, err, context);
      }
          std::unique_ptr<llvm::Module> new_llvm_module;
          // Switch to a new context by dumping and re-parsing LLVM IR. Each
          // thread has its own context to avoid race conditions.
          {
            std::string ir;
            {
              llvm::raw_string_ostream os(ir);
              original_module->print(os, nullptr);
            }
            llvm::SMDiagnostic err;
            new_llvm_module = llvm::parseAssemblyString(ir, err, context);
          }

      compile_results[i] = compile_single_module(
          new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
      counter.DecrementCount();
    });
          compile_results[i] = compile_single_module(
              new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
          counter.DecrementCount();
        });
  }
  counter.Wait();
@@ -229,7 +229,8 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
  }

  int64 replica_count = hlo->GetModule()->config().replica_count();
  if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
  if (replica_count != 1 && !replicas_seen.empty() &&
      replicas_seen.size() != replica_count) {
    return InternalError(
        "Replica count in HloModuleConfig is %d, but ReplicaGroup config "
        "contains %d replicas: %s",