[XLA:GPU] Enforce collectives ordering to be their appearence in the module
The ordering is enforced using control edges. This avoids deadlocks in multi-host launches which arise due to change in ordering due to non-determinism (coming from autotuning). PiperOrigin-RevId: 361069409 Change-Id: I499639aafc74f4128226bccf55fe6d48ecce3f67
This commit is contained in:
parent
f13ee79e03
commit
d379154ca0
@ -2214,6 +2214,53 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "collectives_schedule_linearizer",
|
||||
srcs = ["collectives_schedule_linearizer.cc"],
|
||||
hdrs = ["collectives_schedule_linearizer.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_domain_map",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
":hlo_reachability",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "collectives_schedule_linearizer_test",
|
||||
srcs = ["collectives_schedule_linearizer_test.cc"],
|
||||
deps = [
|
||||
":collectives_schedule_linearizer",
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
":hlo_matchers",
|
||||
":hlo_runner",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "all_reduce_combiner",
|
||||
srcs = ["all_reduce_combiner.cc"],
|
||||
|
@ -0,0 +1,68 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_query.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// TODO(b/181653482): Fix for interprocedural collectives as well.
|
||||
StatusOr<bool> CollectivesScheduleLinearizer::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
std::unique_ptr<HloReachabilityMap> reachability =
|
||||
HloReachabilityMap::Build(computation);
|
||||
HloCollectiveInstruction* prev = nullptr;
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (auto* next = DynCast<HloCollectiveInstruction>(instruction)) {
|
||||
if (prev != nullptr && !reachability->IsConnected(next, prev)) {
|
||||
// We check for reachability as we don't want to form a cycle.
|
||||
TF_RETURN_IF_ERROR(prev->AddControlDependencyTo(next));
|
||||
VLOG(1) << "Adding control dependency from " << prev->ToString()
|
||||
<< " to " << next->ToString();
|
||||
changed = true;
|
||||
}
|
||||
prev = next;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -0,0 +1,46 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Enforces a total order on all collectives present in the module, based on the
|
||||
// order given to the instructions.
|
||||
//
|
||||
// Does not insert inter-computation dependencies, only linearizes the order
|
||||
// within each computation.
|
||||
class CollectivesScheduleLinearizer : public HloModulePass {
|
||||
public:
|
||||
absl::string_view name() const override {
|
||||
return "collectives-schedule-linearizer";
|
||||
}
|
||||
|
||||
CollectivesScheduleLinearizer() = default;
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_
|
@ -0,0 +1,117 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = match;
|
||||
|
||||
int64 CountControlEdges(const HloComputation& computation) {
|
||||
int64 count = 0;
|
||||
for (const auto& instruction : computation.instructions()) {
|
||||
count += instruction->control_successors().size();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
class CollectivesScheduleLinearizerTest : public HloTestBase {
|
||||
protected:
|
||||
void InsertCollectivesSchedule(HloModule* module) {
|
||||
CollectivesScheduleLinearizer collectives_schedule_linearizer;
|
||||
ASSERT_IS_OK(collectives_schedule_linearizer.Run(module).status());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(CollectivesScheduleLinearizerTest, FixOrdering) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
sum {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT out = f32[] add(a, b)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[100] parameter(0), parameter_replication={false}
|
||||
p1 = f32[100] parameter(1), parameter_replication={false}
|
||||
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
|
||||
c2 = f32[100] all-reduce(p1), replica_groups={}, to_apply=sum
|
||||
ROOT out = f32[100] add(c1, c2)
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCollectivesSchedule(module.get());
|
||||
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
|
||||
HloInstruction *c1 = nullptr, *c2 = nullptr;
|
||||
for (HloInstruction* instr : module->entry_computation()->instructions()) {
|
||||
if (Match(instr, m::AllReduce(m::Parameter(0)))) {
|
||||
c1 = instr;
|
||||
}
|
||||
if (Match(instr, m::AllReduce(m::Parameter(1)))) {
|
||||
c2 = instr;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(c1 != nullptr && c2 != nullptr);
|
||||
EXPECT_TRUE(absl::c_linear_search(c2->control_predecessors(), c1));
|
||||
}
|
||||
|
||||
TEST_F(CollectivesScheduleLinearizerTest, NoFixRequired) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
sum {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT out = f32[] add(a, b)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[100] parameter(0), parameter_replication={false}
|
||||
p1 = f32[100] parameter(1), parameter_replication={false}
|
||||
c1 = f32[100] all-reduce(p0), replica_groups={}, to_apply=sum
|
||||
c2 = f32[100] all-reduce(p1), replica_groups={}, to_apply=sum, control-predecessors={c1}
|
||||
ROOT out = f32[100] add(c1, c2)
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCollectivesSchedule(module.get());
|
||||
EXPECT_EQ(CountControlEdges(*module->entry_computation()), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -1282,6 +1282,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:call_inliner",
|
||||
"//tensorflow/compiler/xla/service:collectives_schedule_linearizer",
|
||||
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/collectives_schedule_linearizer.h"
|
||||
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||
@ -376,6 +377,13 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
/*combine_threshold_count=*/256);
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
|
||||
{
|
||||
HloPassPipeline pipeline("collectives_schedule_linearizer");
|
||||
pipeline.AddPass<CollectivesScheduleLinearizer>();
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
|
||||
{
|
||||
// Now we allow to replace any transposes outside of fusions with bitcasts.
|
||||
HloPassPipeline pipeline("final_algebraic_simplifier");
|
||||
|
Loading…
Reference in New Issue
Block a user